diff --git a/Cargo.toml b/Cargo.toml
index 6e7235f9f..32d7bd990 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "tantivy"
-version = "0.25.0"
+version = "0.26.0"
 authors = ["Paul Masurel <paul.masurel@gmail.com>"]
 license = "MIT"
 categories = ["database-implementations", "data-structures"]
diff --git a/benches/and_or_queries.rs b/benches/and_or_queries.rs
index 3e6e0dcd3..805061c18 100644
--- a/benches/and_or_queries.rs
+++ b/benches/and_or_queries.rs
@@ -20,10 +20,11 @@ use binggan::{black_box, BenchGroup, BenchRunner};
 use rand::prelude::*;
 use rand::rngs::StdRng;
 use rand::SeedableRng;
+use tantivy::collector::sort_key::SortByStaticFastValue;
 use tantivy::collector::{Collector, Count, TopDocs};
 use tantivy::query::{Query, QueryParser};
 use tantivy::schema::{Schema, FAST, TEXT};
-use tantivy::{doc, Index, Order, ReloadPolicy, Searcher, SegmentReader};
+use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
 
 #[derive(Clone)]
 struct BenchIndex {
@@ -159,7 +160,7 @@ fn main() {
         &mut group,
         &bench_index,
         query_str,
-        TopDocs::with_limit(10),
+        TopDocs::with_limit(10).order_by_score(),
         "top10",
     );
     add_bench_task(
@@ -173,15 +174,10 @@ fn main() {
         &mut group,
         &bench_index,
         query_str,
-        TopDocs::with_limit(10).custom_score(move |reader: &SegmentReader| {
-            let score_col = reader.fast_fields().u64("score").unwrap();
-            let score_col2 = reader.fast_fields().u64("score2").unwrap();
-            move |doc| {
-                let score = score_col.first(doc);
-                let score2 = score_col2.first(doc);
-                (score, score2)
-            }
-        }),
+        TopDocs::with_limit(10).order_by((
+            SortByStaticFastValue::<u64>::for_field("score"),
+            SortByStaticFastValue::<u64>::for_field("score2"),
+        )),
         "top10_by_2ff",
     );
 }
diff --git a/common/src/vint.rs b/common/src/vint.rs
index 75f647a9c..22eaa267c 100644
--- a/common/src/vint.rs
+++ b/common/src/vint.rs
@@ -28,6 +28,7 @@ impl BinarySerializable for VIntU128 {
         writer.write_all(&buffer)
     }
 
+    #[allow(clippy::unbuffered_bytes)]
     fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
         #[allow(clippy::unbuffered_bytes)]
         let mut bytes = reader.bytes();
@@ -196,6 +197,7 @@ impl BinarySerializable for VInt {
         writer.write_all(&buffer[0..num_bytes])
     }
 
+    #[allow(clippy::unbuffered_bytes)]
     fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
         #[allow(clippy::unbuffered_bytes)]
         let mut bytes = reader.bytes();
diff --git a/examples/basic_search.rs b/examples/basic_search.rs
index 1c62c351d..3d01a47fa 100644
--- a/examples/basic_search.rs
+++ b/examples/basic_search.rs
@@ -208,7 +208,7 @@ fn main() -> tantivy::Result<()> {
     // is the role of the `TopDocs` collector.
 
     // We can now perform our query.
-    let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
+    let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
 
     // The actual documents still need to be
     // retrieved from Tantivy's store.
@@ -226,7 +226,7 @@ fn main() -> tantivy::Result<()> {
     let query = query_parser.parse_query("title:sea^20 body:whale^70")?;
 
     let (_score, doc_address) = searcher
-        .search(&query, &TopDocs::with_limit(1))?
+        .search(&query, &TopDocs::with_limit(1).order_by_score())?
        .into_iter()
        .next()
        .unwrap();
diff --git a/examples/custom_tokenizer.rs b/examples/custom_tokenizer.rs
index 6ec6047c8..1844fb286 100644
--- a/examples/custom_tokenizer.rs
+++ b/examples/custom_tokenizer.rs
@@ -100,7 +100,7 @@ fn main() -> tantivy::Result<()> {
     // here we want to get a hit on the 'ken' in Frankenstein
     let query = query_parser.parse_query("ken")?;
-    let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
+    let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
 
     for (_, doc_address) in top_docs {
         let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;
diff --git a/examples/date_time_field.rs b/examples/date_time_field.rs
index c00a22f0d..a5da06c9c 100644
--- a/examples/date_time_field.rs
+++ b/examples/date_time_field.rs
@@ -50,14 +50,14 @@ fn main() -> tantivy::Result<()> {
     {
         // Simple exact search on the date
         let query = query_parser.parse_query("occurred_at:\"2022-06-22T12:53:50.53Z\"")?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
         assert_eq!(count_docs.len(), 1);
     }
     {
         // Range query on the date field
         let query = query_parser
            .parse_query(r#"occurred_at:[2022-06-22T12:58:00Z TO 2022-06-23T00:00:00Z}"#)?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(4))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(4).order_by_score())?;
         assert_eq!(count_docs.len(), 1);
         for (_score, doc_address) in count_docs {
             let retrieved_doc = searcher.doc::<TantivyDocument>(doc_address)?;
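Reviewer note: as the example hunks above and below show, `TopDocs::with_limit(..)` is no longer a collector by itself in 0.26; the sort order is now stated explicitly. A minimal migration sketch (the `title` field and `diary` query are illustrative, not from this diff):

```rust
use tantivy::collector::TopDocs;
use tantivy::query::QueryParser;
use tantivy::{Index, TantivyDocument};

fn top_ten(index: &Index) -> tantivy::Result<()> {
    let title = index.schema().get_field("title")?;
    let query = QueryParser::for_index(index, vec![title]).parse_query("diary")?;
    let searcher = index.reader()?.searcher();
    // Before (0.25): searcher.search(&query, &TopDocs::with_limit(10))?;
    // After (0.26): the ordering (here, descending BM25 score) is explicit.
    let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
    for (_score, doc_address) in top_docs {
        let _doc: TantivyDocument = searcher.doc(doc_address)?;
    }
    Ok(())
}
```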
diff --git a/examples/deleting_updating_documents.rs b/examples/deleting_updating_documents.rs
index ac9859a42..c6e4a04a2 100644
--- a/examples/deleting_updating_documents.rs
+++ b/examples/deleting_updating_documents.rs
@@ -28,7 +28,7 @@ fn extract_doc_given_isbn(
     // The second argument is here to tell we don't care about decoding positions,
     // or term frequencies.
     let term_query = TermQuery::new(isbn_term.clone(), IndexRecordOption::Basic);
-    let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1))?;
+    let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1).order_by_score())?;
 
     if let Some((_score, doc_address)) = top_docs.first() {
         let doc = searcher.doc(*doc_address)?;
diff --git a/examples/fuzzy_search.rs b/examples/fuzzy_search.rs
index d1aa9f7eb..31d6a1e7d 100644
--- a/examples/fuzzy_search.rs
+++ b/examples/fuzzy_search.rs
@@ -145,7 +145,7 @@ fn main() -> tantivy::Result<()> {
         let query = FuzzyTermQuery::new(term, 2, true);
         let (top_docs, count) = searcher
-            .search(&query, &(TopDocs::with_limit(5), Count))
+            .search(&query, &(TopDocs::with_limit(5).order_by_score(), Count))
             .unwrap();
         assert_eq!(count, 3);
         assert_eq!(top_docs.len(), 3);
diff --git a/examples/ip_field.rs b/examples/ip_field.rs
index 17caf55e8..12ad673f2 100644
--- a/examples/ip_field.rs
+++ b/examples/ip_field.rs
@@ -69,25 +69,25 @@ fn main() -> tantivy::Result<()> {
     {
         // Inclusive range queries
         let query = query_parser.parse_query("ip:[192.168.0.80 TO 192.168.0.100]")?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
         assert_eq!(count_docs.len(), 1);
     }
     {
         // Exclusive range queries
         let query = query_parser.parse_query("ip:{192.168.0.80 TO 192.168.1.100]")?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(count_docs.len(), 0);
     }
     {
         // Find docs with IP addresses smaller equal 192.168.1.100
         let query = query_parser.parse_query("ip:[* TO 192.168.1.100]")?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(count_docs.len(), 2);
     }
     {
         // Find docs with IP addresses smaller than 192.168.1.100
         let query = query_parser.parse_query("ip:[* TO 192.168.1.100}")?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(count_docs.len(), 2);
     }
diff --git a/examples/json_field.rs b/examples/json_field.rs
index 334ea10e4..73399adbc 100644
--- a/examples/json_field.rs
+++ b/examples/json_field.rs
@@ -59,12 +59,12 @@ fn main() -> tantivy::Result<()> {
     let query_parser = QueryParser::for_index(&index, vec![event_type, attributes]);
     {
         let query = query_parser.parse_query("target:submit-button")?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(count_docs.len(), 2);
     }
     {
         let query = query_parser.parse_query("target:submit")?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(count_docs.len(), 2);
     }
     {
@@ -74,33 +74,33 @@ fn main() -> tantivy::Result<()> {
     }
     {
         let query = query_parser.parse_query("click AND cart.product_id:133")?;
-        let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(hits.len(), 1);
     }
     {
         // The sub-fields in the json field marked as default field still need to be explicitly
         // addressed
         let query = query_parser.parse_query("click AND 133")?;
-        let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(hits.len(), 0);
     }
     {
         // Default json fields are ignored if they collide with the schema
         let query = query_parser.parse_query("event_type:holiday-sale")?;
-        let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(hits.len(), 0);
     }
     // # Query via full attribute path
     {
         // This only searches in our schema's `event_type` field
         let query = query_parser.parse_query("event_type:click")?;
-        let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(hits.len(), 2);
     }
     {
         // Default json fields can still be accessed by full path
         let query = query_parser.parse_query("attributes.event_type:holiday-sale")?;
-        let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(hits.len(), 1);
     }
     Ok(())
diff --git a/examples/phrase_prefix_search.rs b/examples/phrase_prefix_search.rs
index 24239dc7a..e2e1922cb 100644
--- a/examples/phrase_prefix_search.rs
+++ b/examples/phrase_prefix_search.rs
@@ -63,7 +63,7 @@ fn main() -> Result<()> {
     // but not "in the Gulf Stream".
     let query = query_parser.parse_query("\"in the su\"*")?;
-    let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
+    let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
     let mut titles = top_docs
         .into_iter()
         .map(|(_score, doc_address)| {
diff --git a/examples/pre_tokenized_text.rs b/examples/pre_tokenized_text.rs
index 4cd4e930e..d977da722 100644
--- a/examples/pre_tokenized_text.rs
+++ b/examples/pre_tokenized_text.rs
@@ -107,7 +107,8 @@ fn main() -> tantivy::Result<()> {
         IndexRecordOption::Basic,
     );
 
-    let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
+    let (top_docs, count) =
+        searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
 
     assert_eq!(count, 2);
 
@@ -128,7 +129,8 @@ fn main() -> tantivy::Result<()> {
         IndexRecordOption::Basic,
     );
 
-    let (_top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
+    let (_top_docs, count) =
+        searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
 
     assert_eq!(count, 0);
 
diff --git a/examples/snippet.rs b/examples/snippet.rs
index 31bd2c166..04edee82f 100644
--- a/examples/snippet.rs
+++ b/examples/snippet.rs
@@ -50,7 +50,7 @@ fn main() -> tantivy::Result<()> {
     let query_parser = QueryParser::for_index(&index, vec![title, body]);
     let query = query_parser.parse_query("sycamore spring")?;
-    let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
+    let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
 
     let snippet_generator = SnippetGenerator::create(&searcher, &*query, body)?;
 
diff --git a/examples/stop_words.rs b/examples/stop_words.rs
index 80dab3feb..bc29da47d 100644
--- a/examples/stop_words.rs
+++ b/examples/stop_words.rs
@@ -102,7 +102,7 @@ fn main() -> tantivy::Result<()> {
     // stop words are applied on the query as well.
     // The following will be equivalent to `title:frankenstein`
     let query = query_parser.parse_query("title:\"the Frankenstein\"")?;
-    let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
+    let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
 
     for (score, doc_address) in top_docs {
         let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;
diff --git a/examples/warmer.rs b/examples/warmer.rs
index 1cae9d349..c7543114a 100644
--- a/examples/warmer.rs
+++ b/examples/warmer.rs
@@ -164,7 +164,7 @@ fn main() -> tantivy::Result<()> {
         move |doc_id: DocId| Reverse(price[doc_id as usize])
     };
 
-    let most_expensive_first = TopDocs::with_limit(10).custom_score(score_by_price);
+    let most_expensive_first = TopDocs::with_limit(10).order_by(score_by_price);
     let hits = searcher.search(&query, &most_expensive_first)?;
     assert_eq!(
diff --git a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs
index 8156a1b66..6a8bdf826 100644
--- a/src/aggregation/metric/top_hits.rs
+++ b/src/aggregation/metric/top_hits.rs
@@ -16,6 +16,7 @@ use crate::aggregation::intermediate_agg_result::{
 };
 use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
 use crate::aggregation::AggregationError;
+use crate::collector::sort_key::ReverseComparator;
 use crate::collector::TopNComputer;
 use crate::schema::OwnedValue;
 use crate::{DocAddress, DocId, SegmentOrdinal};
@@ -458,7 +459,7 @@ impl Eq for DocSortValuesAndFields {}
 #[derive(Clone, Serialize, Deserialize, Debug)]
 pub struct TopHitsTopNComputer {
     req: TopHitsAggregationReq,
-    top_n: TopNComputer<DocSortValuesAndFields, DocAddress, false>,
+    top_n: TopNComputer<DocSortValuesAndFields, DocAddress, ReverseComparator>,
 }
 
 impl std::cmp::PartialEq for TopHitsTopNComputer {
@@ -482,7 +483,7 @@ impl TopHitsTopNComputer {
 
     pub(crate) fn merge_fruits(&mut self, other_fruit: Self) -> crate::Result<()> {
         for doc in other_fruit.top_n.into_vec() {
-            self.collect(doc.feature, doc.doc);
+            self.collect(doc.sort_key, doc.doc);
         }
         Ok(())
     }
@@ -494,9 +495,9 @@ impl TopHitsTopNComputer {
             .into_sorted_vec()
             .into_iter()
             .map(|doc| TopHitsVecEntry {
-                sort: doc.feature.sorts.iter().map(|f| f.value).collect(),
+                sort: doc.sort_key.sorts.iter().map(|f| f.value).collect(),
                 doc_value_fields: doc
-                    .feature
+                    .sort_key
                     .doc_value_fields
                     .into_iter()
                     .map(|(k, v)| (k, v.into()))
@@ -517,7 +518,7 @@ impl TopHitsTopNComputer {
 pub(crate) struct TopHitsSegmentCollector {
     segment_ordinal: SegmentOrdinal,
     accessor_idx: usize,
-    top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, false>,
+    top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>,
 }
 
 impl TopHitsSegmentCollector {
@@ -544,7 +545,7 @@ impl TopHitsSegmentCollector {
         let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
         top_hits_computer.collect(
             DocSortValuesAndFields {
-                sorts: res.feature,
+                sorts: res.sort_key,
                 doc_value_fields,
             },
             res.doc,
@@ -645,6 +646,7 @@ mod tests {
     use crate::aggregation::bucket::tests::get_test_index_from_docs;
     use crate::aggregation::tests::get_test_index_from_values;
     use crate::aggregation::AggregationCollector;
+    use crate::collector::sort_key::ReverseComparator;
     use crate::collector::ComparableDoc;
     use crate::query::AllQuery;
     use crate::schema::OwnedValue;
@@ -660,7 +662,7 @@ mod tests {
 
     fn collector_with_capacity(capacity: usize) -> super::TopHitsTopNComputer {
         super::TopHitsTopNComputer {
-            top_n: super::TopNComputer::new(capacity),
+            top_n: super::TopNComputer::new_with_comparator(capacity, ReverseComparator),
             req: Default::default(),
         }
     }
@@ -774,12 +776,12 @@ mod tests {
 
     #[test]
     fn test_top_hits_collector_single_feature() ->
crate::Result<()> {
         let docs = vec![
-            ComparableDoc::<_, _, false> {
+            ComparableDoc::<_, _> {
                 doc: crate::DocAddress {
                     segment_ord: 0,
                     doc_id: 0,
                 },
-                feature: DocSortValuesAndFields {
+                sort_key: DocSortValuesAndFields {
                     sorts: vec![DocValueAndOrder {
                         value: Some(1),
                         order: Order::Asc,
@@ -792,7 +794,7 @@ mod tests {
                     segment_ord: 0,
                     doc_id: 2,
                 },
-                feature: DocSortValuesAndFields {
+                sort_key: DocSortValuesAndFields {
                     sorts: vec![DocValueAndOrder {
                         value: Some(3),
                         order: Order::Asc,
@@ -805,7 +807,7 @@ mod tests {
                     segment_ord: 0,
                     doc_id: 1,
                 },
-                feature: DocSortValuesAndFields {
+                sort_key: DocSortValuesAndFields {
                     sorts: vec![DocValueAndOrder {
                         value: Some(5),
                         order: Order::Asc,
@@ -817,7 +819,7 @@ mod tests {
 
         let mut collector = collector_with_capacity(3);
         for doc in docs.clone() {
-            collector.collect(doc.feature, doc.doc);
+            collector.collect(doc.sort_key, doc.doc);
         }
 
         let res = collector.into_final_result();
@@ -827,15 +829,15 @@ mod tests {
             super::TopHitsMetricResult {
                 hits: vec![
                     super::TopHitsVecEntry {
-                        sort: vec![docs[0].feature.sorts[0].value],
+                        sort: vec![docs[0].sort_key.sorts[0].value],
                         doc_value_fields: Default::default(),
                     },
                     super::TopHitsVecEntry {
-                        sort: vec![docs[1].feature.sorts[0].value],
+                        sort: vec![docs[1].sort_key.sorts[0].value],
                         doc_value_fields: Default::default(),
                     },
                     super::TopHitsVecEntry {
-                        sort: vec![docs[2].feature.sorts[0].value],
+                        sort: vec![docs[2].sort_key.sorts[0].value],
                         doc_value_fields: Default::default(),
                     },
                 ]
diff --git a/src/collector/custom_score_top_collector.rs b/src/collector/custom_score_top_collector.rs
deleted file mode 100644
index 54d42469e..000000000
--- a/src/collector/custom_score_top_collector.rs
+++ /dev/null
@@ -1,121 +0,0 @@
-use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
-use crate::collector::{Collector, SegmentCollector};
-use crate::{DocAddress, DocId, Score, SegmentReader};
-
-pub(crate) struct CustomScoreTopCollector<TCustomScorer, TScore> {
-    custom_scorer: TCustomScorer,
-    collector: TopCollector<TScore>,
-}
-
-impl<TCustomScorer, TScore> CustomScoreTopCollector<TCustomScorer, TScore>
-where TScore: Clone + PartialOrd
-{
-    pub(crate) fn new(
-        custom_scorer: TCustomScorer,
-        collector: TopCollector<TScore>,
-    ) -> CustomScoreTopCollector<TCustomScorer, TScore> {
-        CustomScoreTopCollector {
-            custom_scorer,
-            collector,
-        }
-    }
-}
-
-/// A custom segment scorer makes it possible to define any kind of score
-/// for a given document belonging to a specific segment.
-///
-/// It is the segment local version of the [`CustomScorer`].
-pub trait CustomSegmentScorer<TScore>: 'static {
-    /// Computes the score of a specific `doc`.
-    fn score(&mut self, doc: DocId) -> TScore;
-}
-
-/// `CustomScorer` makes it possible to define any kind of score.
-///
-/// The `CustomerScorer` itself does not make much of the computation itself.
-/// Instead, it helps constructing `Self::Child` instances that will compute
-/// the score at a segment scale.
-pub trait CustomScorer<TScore>: Sync {
-    /// Type of the associated [`CustomSegmentScorer`].
-    type Child: CustomSegmentScorer<TScore>;
-    /// Builds a child scorer for a specific segment. The child scorer is associated with
-    /// a specific segment.
-    fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child>;
-}
-
-impl<TCustomScorer, TScore> Collector for CustomScoreTopCollector<TCustomScorer, TScore>
-where
-    TCustomScorer: CustomScorer<TScore> + Send + Sync,
-    TScore: 'static + PartialOrd + Clone + Send + Sync,
-{
-    type Fruit = Vec<(TScore, DocAddress)>;
-
-    type Child = CustomScoreTopSegmentCollector<TCustomScorer::Child, TScore>;
-
-    fn for_segment(
-        &self,
-        segment_local_id: u32,
-        segment_reader: &SegmentReader,
-    ) -> crate::Result<Self::Child> {
-        let segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
-        let segment_scorer = self.custom_scorer.segment_scorer(segment_reader)?;
-        Ok(CustomScoreTopSegmentCollector {
-            segment_collector,
-            segment_scorer,
-        })
-    }
-
-    fn requires_scoring(&self) -> bool {
-        false
-    }
-
-    fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> crate::Result<Self::Fruit> {
-        self.collector.merge_fruits(segment_fruits)
-    }
-}
-
-pub struct CustomScoreTopSegmentCollector<T, TScore>
-where
-    TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
-    T: CustomSegmentScorer<TScore>,
-{
-    segment_collector: TopSegmentCollector<TScore>,
-    segment_scorer: T,
-}
-
-impl<T, TScore> SegmentCollector for CustomScoreTopSegmentCollector<T, TScore>
-where
-    TScore: 'static + PartialOrd + Clone + Send + Sync,
-    T: 'static + CustomSegmentScorer<TScore>,
-{
-    type Fruit = Vec<(TScore, DocAddress)>;
-
-    fn collect(&mut self, doc: DocId, _score: Score) {
-        let score = self.segment_scorer.score(doc);
-        self.segment_collector.collect(doc, score);
-    }
-
-    fn harvest(self) -> Vec<(TScore, DocAddress)> {
-        self.segment_collector.harvest()
-    }
-}
-
-impl<F, TScore, T> CustomScorer<TScore> for F
-where
-    F: 'static + Send + Sync + Fn(&SegmentReader) -> T,
-    T: CustomSegmentScorer<TScore>,
-{
-    type Child = T;
-
-    fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<T> {
-        Ok((self)(segment_reader))
-    }
-}
-
-impl<F, TScore> CustomSegmentScorer<TScore> for F
-where F: 'static + FnMut(DocId) -> TScore
-{
-    fn score(&mut self, doc: DocId) -> TScore {
-        (self)(doc)
-    }
-}
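Reviewer note: the deleted `CustomScorer`/`CustomSegmentScorer` pair is superseded by `order_by`, which accepts the same reader-level closure shape, as the warmer.rs hunk above shows (`.custom_score(score_by_price)` becomes `.order_by(score_by_price)`). A hedged sketch of the replacement pattern — the `price` u64 fast field is illustrative, and the closure presumably goes through a blanket `SortKeyComputer` impl not shown in this excerpt:

```rust
use std::cmp::Reverse;

use tantivy::collector::TopDocs;
use tantivy::query::AllQuery;
use tantivy::{DocId, Index, SegmentReader};

fn search_most_expensive(index: &Index) -> tantivy::Result<()> {
    let searcher = index.reader()?.searcher();
    // Same closure shape as the removed custom_score API; only the method changes.
    let most_expensive_first =
        TopDocs::with_limit(10).order_by(move |segment_reader: &SegmentReader| {
            // Hypothetical "price" u64 fast field.
            let price = segment_reader.fast_fields().u64("price").unwrap();
            move |doc_id: DocId| Reverse(price.first(doc_id))
        });
    let _hits = searcher.search(&AllQuery, &most_expensive_first)?;
    Ok(())
}
```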
diff --git a/src/collector/filter_collector_wrapper.rs b/src/collector/filter_collector_wrapper.rs
index 4e09b027c..b4bada2ff 100644
--- a/src/collector/filter_collector_wrapper.rs
+++ b/src/collector/filter_collector_wrapper.rs
@@ -12,6 +12,7 @@ use std::marker::PhantomData;
 use columnar::{BytesColumn, Column, DynamicColumn, HasAssociatedColumnType};
 
 use crate::collector::{Collector, SegmentCollector};
+use crate::schema::Schema;
 use crate::{DocId, Score, SegmentReader};
 
 /// The `FilterCollector` filters docs using a fast field value and a predicate.
@@ -49,13 +50,13 @@ use crate::{DocId, Score, SegmentReader};
 ///
 /// let query_parser = QueryParser::for_index(&index, vec![title]);
 /// let query = query_parser.parse_query("diary")?;
-/// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2));
+/// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2).order_by_score());
 /// let top_docs = searcher.search(&query, &no_filter_collector)?;
 ///
 /// assert_eq!(top_docs.len(), 1);
 /// assert_eq!(top_docs[0].1, DocAddress::new(0, 1));
 ///
-/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2));
+/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2).order_by_score());
 /// let filtered_top_docs = searcher.search(&query, &filter_all_collector)?;
 ///
 /// assert_eq!(filtered_top_docs.len(), 0);
@@ -104,6 +105,11 @@ where
 
     type Child = FilterSegmentCollector<TCollector::Child, TPredicate, TPredicateValue>;
 
+    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
+        self.collector.check_schema(schema)?;
+        Ok(())
+    }
+
     fn for_segment(
         &self,
         segment_local_id: u32,
@@ -234,7 +240,7 @@ where
 ///
 /// let query_parser = QueryParser::for_index(&index, vec![title]);
 /// let query = query_parser.parse_query("diary")?;
-/// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2));
+/// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2).order_by_score());
 /// let top_docs = searcher.search(&query, &filter_collector)?;
 ///
 /// assert_eq!(top_docs.len(), 1);
@@ -274,6 +280,10 @@ where
 
     type Child = BytesFilterSegmentCollector<TCollector::Child, TPredicate>;
 
+    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
+        self.collector.check_schema(schema)
+    }
+
     fn for_segment(
         &self,
         segment_local_id: u32,
diff --git a/src/collector/mod.rs b/src/collector/mod.rs
index 6c509fcfb..0f8360d8d 100644
--- a/src/collector/mod.rs
+++ b/src/collector/mod.rs
@@ -57,7 +57,7 @@
 //! # let query_parser = QueryParser::for_index(&index, vec![title]);
 //! # let query = query_parser.parse_query("diary")?;
 //! let (doc_count, top_docs): (usize, Vec<(Score, DocAddress)>) =
-//!     searcher.search(&query, &(Count, TopDocs::with_limit(2)))?;
+//!     searcher.search(&query, &(Count, TopDocs::with_limit(2).order_by_score()))?;
 //! # Ok(())
 //! # }
 //!
//! ```
@@ -83,11 +83,15 @@
 use downcast_rs::impl_downcast;
 
+use crate::schema::Schema;
 use crate::{DocId, Score, SegmentOrdinal, SegmentReader};
 
 mod count_collector;
 pub use self::count_collector::Count;
 
+/// Sort keys
+pub mod sort_key;
+
 mod histogram_collector;
 pub use histogram_collector::HistogramCollector;
 
@@ -95,16 +99,13 @@ mod multi_collector;
 pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
 
 mod top_collector;
+pub use self::top_collector::ComparableDoc;
 
 mod top_score_collector;
-pub use self::top_collector::ComparableDoc;
 pub use self::top_score_collector::{TopDocs, TopNComputer};
 
-mod custom_score_top_collector;
-pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer};
-
-mod tweak_score_top_collector;
-pub use self::tweak_score_top_collector::{ScoreSegmentTweaker, ScoreTweaker};
+mod sort_key_top_collector;
+pub use self::sort_key::{SegmentSortKeyComputer, SortKeyComputer};
 
 mod facet_collector;
 pub use self::facet_collector::{FacetCollector, FacetCounts};
 use crate::query::Weight;
@@ -145,6 +146,11 @@ pub trait Collector: Sync + Send {
     /// Type of the `SegmentCollector` associated with this collector.
     type Child: SegmentCollector;
 
+    /// Returns an error if the schema is not compatible with the collector.
+    fn check_schema(&self, _schema: &Schema) -> crate::Result<()> {
+        Ok(())
+    }
+
     /// `set_segment` is called before beginning to enumerate
     /// on this segment.
     fn for_segment(
@@ -170,41 +176,50 @@ pub trait Collector: Sync + Send {
         segment_ord: u32,
         reader: &SegmentReader,
     ) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
+        let with_scoring = self.requires_scoring();
         let mut segment_collector = self.for_segment(segment_ord, reader)?;
-
-        match (reader.alive_bitset(), self.requires_scoring()) {
-            (Some(alive_bitset), true) => {
-                weight.for_each(reader, &mut |doc, score| {
-                    if alive_bitset.is_alive(doc) {
-                        segment_collector.collect(doc, score);
-                    }
-                })?;
-            }
-            (Some(alive_bitset), false) => {
-                weight.for_each_no_score(reader, &mut |docs| {
-                    for doc in docs.iter().cloned() {
-                        if alive_bitset.is_alive(doc) {
-                            segment_collector.collect(doc, 0.0);
-                        }
-                    }
-                })?;
-            }
-            (None, true) => {
-                weight.for_each(reader, &mut |doc, score| {
-                    segment_collector.collect(doc, score);
-                })?;
-            }
-            (None, false) => {
-                weight.for_each_no_score(reader, &mut |docs| {
-                    segment_collector.collect_block(docs);
-                })?;
-            }
-        }
-
+        default_collect_segment_impl(&mut segment_collector, weight, reader, with_scoring)?;
         Ok(segment_collector.harvest())
     }
 }
 
+pub(crate) fn default_collect_segment_impl<TSegmentCollector: SegmentCollector>(
+    segment_collector: &mut TSegmentCollector,
+    weight: &dyn Weight,
+    reader: &SegmentReader,
+    with_scoring: bool,
+) -> crate::Result<()> {
+    match (reader.alive_bitset(), with_scoring) {
+        (Some(alive_bitset), true) => {
+            weight.for_each(reader, &mut |doc, score| {
+                if alive_bitset.is_alive(doc) {
+                    segment_collector.collect(doc, score);
+                }
+            })?;
+        }
+        (Some(alive_bitset), false) => {
+            weight.for_each_no_score(reader, &mut |docs| {
+                for doc in docs.iter().cloned() {
+                    if alive_bitset.is_alive(doc) {
+                        segment_collector.collect(doc, 0.0);
+                    }
+                }
+            })?;
+        }
+        (None, true) => {
+            weight.for_each(reader, &mut |doc, score| {
+                segment_collector.collect(doc, score);
+            })?;
+        }
+        (None, false) => {
+            weight.for_each_no_score(reader, &mut |docs| {
+                segment_collector.collect_block(docs);
+            })?;
+        }
+    }
+    Ok(())
+}
+
 impl<T: SegmentCollector> SegmentCollector for Option<T> {
     type Fruit = Option<T::Fruit>;
 
@@ -230,6 +245,13 @@ impl<T: Collector> Collector for Option<T> {
 
     type Child = Option<<T as Collector>::Child>;
 
+    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
+        if let Some(underlying_collector) = self {
+            underlying_collector.check_schema(schema)?;
+        }
+        Ok(())
+    }
+
     fn for_segment(
         &self,
         segment_local_id: SegmentOrdinal,
@@ -305,6 +327,12 @@ where
     type Fruit = (Left::Fruit, Right::Fruit);
     type Child = (Left::Child, Right::Child);
 
+    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
+        self.0.check_schema(schema)?;
+        self.1.check_schema(schema)?;
+        Ok(())
+    }
+
     fn for_segment(
         &self,
         segment_local_id: u32,
@@ -369,6 +397,13 @@ where
     type Fruit = (One::Fruit, Two::Fruit, Three::Fruit);
     type Child = (One::Child, Two::Child, Three::Child);
 
+    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
+        self.0.check_schema(schema)?;
+        self.1.check_schema(schema)?;
+        self.2.check_schema(schema)?;
+        Ok(())
+    }
+
     fn for_segment(
         &self,
         segment_local_id: u32,
@@ -441,6 +476,14 @@ where
     type Fruit = (One::Fruit, Two::Fruit, Three::Fruit, Four::Fruit);
     type Child = (One::Child, Two::Child, Three::Child, Four::Child);
 
+    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
+        self.0.check_schema(schema)?;
+        self.1.check_schema(schema)?;
+        self.2.check_schema(schema)?;
+        self.3.check_schema(schema)?;
+        Ok(())
+    }
+
     fn for_segment(
         &self,
         segment_local_id: u32,
diff --git a/src/collector/multi_collector.rs b/src/collector/multi_collector.rs
index 7d2196e02..14779c4a4 100644
--- a/src/collector/multi_collector.rs
+++ b/src/collector/multi_collector.rs
@@ -3,6 +3,7 @@ use std::ops::Deref;
 
 use super::{Collector, SegmentCollector};
 use crate::collector::Fruit;
+use crate::schema::Schema;
 use crate::{DocId, Score, SegmentOrdinal, SegmentReader, TantivyError};
 
 /// MultiFruit keeps Fruits from every nested Collector
@@ -16,6 +17,10 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
     type Fruit = Box<dyn Fruit>;
     type Child = Box<dyn BoxableSegmentCollector>;
 
+    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
+        self.0.check_schema(schema)
+    }
+
     fn for_segment(
         &self,
         segment_local_id: u32,
@@ -147,7 +152,7 @@ impl FruitHandle {
 /// let searcher = reader.searcher();
 ///
 /// let mut collectors = MultiCollector::new();
-/// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2));
+/// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
 /// let count_handle = collectors.add_collector(Count);
 /// let query_parser = QueryParser::for_index(&index, vec![title]);
 /// let query = query_parser.parse_query("diary").unwrap();
@@ -194,6 +199,13 @@ impl Collector for MultiCollector<'_> {
     type Fruit = MultiFruit;
     type Child = MultiCollectorChild;
 
+    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
+        for collector in &self.collector_wrappers {
+            collector.check_schema(schema)?;
+        }
+        Ok(())
+    }
+
     fn for_segment(
         &self,
         segment_local_id: SegmentOrdinal,
@@ -299,7 +311,7 @@ mod tests {
         let query = TermQuery::new(term, IndexRecordOption::Basic);
 
         let mut collectors = MultiCollector::new();
-        let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2));
+        let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
         let count_handler = collectors.add_collector(Count);
         let mut multifruits = searcher.search(&query, &collectors).unwrap();
 
diff --git a/src/collector/sort_key/mod.rs b/src/collector/sort_key/mod.rs
new file mode 100644
index 000000000..a66115633
--- /dev/null
+++ b/src/collector/sort_key/mod.rs
@@ -0,0 +1,393 @@
+mod order;
+mod sort_by_score;
+mod sort_by_static_fast_value;
+mod sort_by_string;
+mod sort_key_computer;
+
+pub use order::*;
+pub use sort_by_score::SortBySimilarityScore;
+pub use sort_by_static_fast_value::SortByStaticFastValue;
+pub use sort_by_string::SortByString;
+pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};
+
+#[cfg(test)]
+mod tests {
+    use std::collections::HashMap;
+    use std::ops::Range;
+
+    use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString};
+    use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
+    use crate::indexer::NoMergePolicy;
+    use crate::query::{AllQuery, QueryParser};
+    use crate::schema::{Schema, FAST, TEXT};
+    use crate::{DocAddress, Document, Index, Order, Score, Searcher};
+
+    fn make_index() -> crate::Result<Index> {
+        let mut schema_builder = Schema::builder();
+        let id = schema_builder.add_u64_field("id", FAST);
+        let city = schema_builder.add_text_field("city", TEXT | FAST);
+        let catchphrase = schema_builder.add_text_field("catchphrase", TEXT);
+        let altitude = schema_builder.add_f64_field("altitude", FAST);
+        let schema = schema_builder.build();
+        let index = Index::create_in_ram(schema);
+
+        fn create_segment<D: Document>(index: &Index, docs: Vec<D>) -> crate::Result<()> {
+            let mut index_writer = index.writer_for_tests()?;
+            index_writer.set_merge_policy(Box::new(NoMergePolicy));
+            for doc in docs {
+                index_writer.add_document(doc)?;
+            }
+            index_writer.commit()?;
+            Ok(())
+        }
+
+        create_segment(
+            &index,
+            vec![
+                doc!(
+                    id => 0_u64,
+                    city => "austin",
+                    catchphrase => "Hills, Barbeque, Glow",
+                    altitude => 149.0,
+                ),
+                doc!(
+                    id => 1_u64,
+                    city => "greenville",
+                    catchphrase => "Grow, Glow, Glow",
+                    altitude => 27.0,
+                ),
+            ],
+        )?;
+        create_segment(
+            &index,
+            vec![doc!(
+                id => 2_u64,
+                city => "tokyo",
+                catchphrase => "Glow, Glow, Glow",
+                altitude => 40.0,
+            )],
+        )?;
+        create_segment(
+            &index,
+            vec![doc!(
+                id => 3_u64,
+                catchphrase => "No, No, No",
+                altitude => 0.0,
+            )],
+        )?;
+        Ok(index)
+    }
+
+    // NOTE: You cannot determine the SegmentIds that will be generated for Segments
+    // ahead of time, so DocAddresses must be mapped back to a unique id for each Searcher.
+    fn id_mapping(searcher: &Searcher) -> HashMap<DocAddress, u64> {
+        searcher
+            .search(&AllQuery, &DocSetCollector)
+            .unwrap()
+            .into_iter()
+            .map(|doc_address| {
+                let column = searcher.segment_readers()[doc_address.segment_ord as usize]
+                    .fast_fields()
+                    .u64("id")
+                    .unwrap();
+                (doc_address, column.first(doc_address.doc_id).unwrap())
+            })
+            .collect()
+    }
+
+    #[test]
+    fn test_order_by_string() -> crate::Result<()> {
+        let index = make_index()?;
+
+        #[track_caller]
+        fn assert_query(
+            index: &Index,
+            order: Order,
+            doc_range: Range<usize>,
+            expected: Vec<(Option<String>, u64)>,
+        ) -> crate::Result<()> {
+            let searcher = index.reader()?.searcher();
+            let ids = id_mapping(&searcher);
+
+            // Try as primitive.
+            let top_collector = TopDocs::for_doc_range(doc_range)
+                .order_by((SortByString::for_field("city"), order));
+            let actual = searcher
+                .search(&AllQuery, &top_collector)?
+                .into_iter()
+                .map(|(sort_key_opt, doc)| (sort_key_opt, ids[&doc]))
+                .collect::<Vec<_>>();
+            assert_eq!(actual, expected);
+            Ok(())
+        }
+
+        assert_query(
+            &index,
+            Order::Asc,
+            0..4,
+            vec![
+                (Some("austin".to_owned()), 0),
+                (Some("greenville".to_owned()), 1),
+                (Some("tokyo".to_owned()), 2),
+                (None, 3),
+            ],
+        )?;
+
+        assert_query(
+            &index,
+            Order::Asc,
+            0..3,
+            vec![
+                (Some("austin".to_owned()), 0),
+                (Some("greenville".to_owned()), 1),
+                (Some("tokyo".to_owned()), 2),
+            ],
+        )?;
+
+        assert_query(
+            &index,
+            Order::Asc,
+            0..2,
+            vec![
+                (Some("austin".to_owned()), 0),
+                (Some("greenville".to_owned()), 1),
+            ],
+        )?;
+
+        assert_query(
+            &index,
+            Order::Asc,
+            0..1,
+            vec![(Some("austin".to_string()), 0)],
+        )?;
+
+        assert_query(
+            &index,
+            Order::Asc,
+            1..3,
+            vec![
+                (Some("greenville".to_owned()), 1),
+                (Some("tokyo".to_owned()), 2),
+            ],
+        )?;
+
+        assert_query(
+            &index,
+            Order::Desc,
+            0..4,
+            vec![
+                (Some("tokyo".to_owned()), 2),
+                (Some("greenville".to_owned()), 1),
+                (Some("austin".to_owned()), 0),
+                (None, 3),
+            ],
+        )?;
+
+        assert_query(
+            &index,
+            Order::Desc,
+            1..3,
+            vec![
+                (Some("greenville".to_owned()), 1),
+                (Some("austin".to_owned()), 0),
+            ],
+        )?;
+
+        assert_query(
+            &index,
+            Order::Desc,
+            0..1,
+            vec![(Some("tokyo".to_owned()), 2)],
+        )?;
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_order_by_f64() -> crate::Result<()> {
+        let index = make_index()?;
+
+        fn assert_query(
+            index: &Index,
+            order: Order,
+            expected: Vec<(Option<f64>, u64)>,
+        ) -> crate::Result<()> {
+            let searcher = index.reader()?.searcher();
+            let ids = id_mapping(&searcher);
+
+            // Try as primitive.
+            let top_collector = TopDocs::with_limit(3)
+                .order_by((SortByStaticFastValue::<f64>::for_field("altitude"), order));
+            let actual = searcher
+                .search(&AllQuery, &top_collector)?
+                .into_iter()
+                .map(|(altitude_opt, doc)| (altitude_opt, ids[&doc]))
+                .collect::<Vec<_>>();
+            assert_eq!(actual, expected);
+
+            Ok(())
+        }
+
+        assert_query(
+            &index,
+            Order::Asc,
+            vec![(Some(0.0), 3), (Some(27.0), 1), (Some(40.0), 2)],
+        )?;
+
+        assert_query(
+            &index,
+            Order::Desc,
+            vec![(Some(149.0), 0), (Some(40.0), 2), (Some(27.0), 1)],
+        )?;
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_order_by_score() -> crate::Result<()> {
+        let index = make_index()?;
+
+        fn query(index: &Index, order: Order) -> crate::Result<Vec<(Score, u64)>> {
+            let searcher = index.reader()?.searcher();
+            let ids = id_mapping(&searcher);
+
+            let top_collector = TopDocs::with_limit(4).order_by((SortBySimilarityScore, order));
+            let field = index.schema().get_field("catchphrase").unwrap();
+            let query_parser = QueryParser::for_index(index, vec![field]);
+            let text_query = query_parser.parse_query("glow")?;
+
+            Ok(searcher
+                .search(&text_query, &top_collector)?
+                .into_iter()
+                .map(|(score, doc)| (score, ids[&doc]))
+                .collect())
+        }
+
+        assert_eq!(
+            &query(&index, Order::Desc)?,
+            &[(0.5604893, 2), (0.4904281, 1), (0.35667497, 0),]
+        );
+
+        assert_eq!(
+            &query(&index, Order::Asc)?,
+            &[(0.35667497, 0), (0.4904281, 1), (0.5604893, 2),]
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_order_by_score_then_string() -> crate::Result<()> {
+        let index = make_index()?;
+
+        type SortKey = (Score, Option<String>);
+
+        fn query(
+            index: &Index,
+            score_order: Order,
+            city_order: Order,
+        ) -> crate::Result<Vec<(SortKey, u64)>> {
+            let searcher = index.reader()?.searcher();
+            let ids = id_mapping(&searcher);
+
+            let top_collector = TopDocs::with_limit(4).order_by((
+                (SortBySimilarityScore, score_order),
+                (SortByString::for_field("city"), city_order),
+            ));
+            Ok(searcher
+                .search(&AllQuery, &top_collector)?
+                .into_iter()
+                .map(|(f, doc)| (f, ids[&doc]))
+                .collect())
+        }
+
+        assert_eq!(
+            &query(&index, Order::Asc, Order::Asc)?,
+            &[
+                ((1.0, Some("austin".to_owned())), 0),
+                ((1.0, Some("greenville".to_owned())), 1),
+                ((1.0, Some("tokyo".to_owned())), 2),
+                ((1.0, None), 3),
+            ]
+        );
+
+        assert_eq!(
+            &query(&index, Order::Asc, Order::Desc)?,
+            &[
+                ((1.0, Some("tokyo".to_owned())), 2),
+                ((1.0, Some("greenville".to_owned())), 1),
+                ((1.0, Some("austin".to_owned())), 0),
+                ((1.0, None), 3),
+            ]
+        );
+        Ok(())
+    }
+
+    use proptest::prelude::*;
+
+    proptest! {
+        #[test]
+        fn test_order_by_string_prop(
+            order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
+            limit in 1..64_usize,
+            offset in 0..64_usize,
+            segments_terms in
+                proptest::collection::vec(
+                    proptest::collection::vec(0..32_u8, 1..32_usize),
+                    0..8_usize,
+                )
+        ) {
+            let mut schema_builder = Schema::builder();
+            let city = schema_builder.add_text_field("city", TEXT | FAST);
+            let schema = schema_builder.build();
+            let index = Index::create_in_ram(schema);
+            let mut index_writer = index.writer_for_tests()?;
+
+            // A Vec<Vec<u8>>, where the outer Vec represents segments, and the inner Vec
+            // represents terms.
+            for segment_terms in segments_terms.into_iter() {
+                for term in segment_terms.into_iter() {
+                    let term = format!("{term:0>3}");
+                    index_writer.add_document(doc!(
+                        city => term,
+                    ))?;
+                }
+                index_writer.commit()?;
+            }
+
+            let searcher = index.reader()?.searcher();
+            let top_n_results = searcher.search(&AllQuery, &TopDocs::with_limit(limit)
+                .and_offset(offset)
+                .order_by_string_fast_field("city", order))?;
+            let all_results = searcher.search(&AllQuery, &DocSetCollector)?.into_iter().map(|doc_address| {
+                // Get the term for this address.
+                let column = searcher.segment_readers()[doc_address.segment_ord as usize].fast_fields().str("city").unwrap().unwrap();
+                let value = column.term_ords(doc_address.doc_id).next().map(|term_ord| {
+                    let mut city = Vec::new();
+                    column.dictionary().ord_to_term(term_ord, &mut city).unwrap();
+                    String::try_from(city).unwrap()
+                });
+                (value, doc_address)
+            });
+
+            // Using the TopDocs collector should always be equivalent to sorting, skipping the
+            // offset, and then taking the limit.
+            let sorted_docs: Vec<_> = if order.is_desc() {
+                let mut comparable_docs: Vec<ComparableDoc<_, _>> =
+                    all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
+                comparable_docs.sort();
+                comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
+            } else {
+                let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
+                    all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
+                comparable_docs.sort();
+                comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
+            };
+            let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
+            prop_assert_eq!(
+                expected_docs,
+                top_n_results
+            );
+        }
+    }
+}
diff --git a/src/collector/sort_key/order.rs b/src/collector/sort_key/order.rs
new file mode 100644
index 000000000..923d5cb8e
--- /dev/null
+++ b/src/collector/sort_key/order.rs
@@ -0,0 +1,348 @@
+use std::cmp::Ordering;
+
+use serde::{Deserialize, Serialize};
+
+use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
+use crate::schema::Schema;
+use crate::{DocId, Order, Score};
+
+/// Comparator trait defining the order in which documents should be returned.
+pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
+    /// Return the order between two values.
+    fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
+}
+
+/// With the natural comparator, the top k collector will return
+/// the top documents in decreasing order.
+#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
+pub struct NaturalComparator;
+
+impl<T: PartialOrd> Comparator<T> for NaturalComparator {
+    #[inline(always)]
+    fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
+        lhs.partial_cmp(rhs).unwrap()
+    }
+}
+
+/// Sorts documents in reverse order.
+///
+/// If the sort key is None, it will be considered the lowest value, and will therefore appear
+/// first.
+///
+/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
+/// to the NaturalComparator. In the presence of a tie, both versions will retain the higher doc ids.
+#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
+pub struct ReverseComparator;
+
+impl<T> Comparator<T> for ReverseComparator
+where NaturalComparator: Comparator<T>
+{
+    #[inline(always)]
+    fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
+        NaturalComparator.compare(rhs, lhs)
+    }
+}
+
+/// Sorts documents in reverse order, but considers None as having the lowest value.
+///
+/// This is usually what is wanted when sorting by a field in ascending order.
+/// For instance, in an e-commerce website, if I sort by price ascending, I most likely want the
+/// cheapest items first, and the items without a price last.
+#[derive(Debug, Copy, Clone, Default)]
+pub struct ReverseNoneIsLowerComparator;
+
+impl<T> Comparator<Option<T>> for ReverseNoneIsLowerComparator
+where ReverseComparator: Comparator<T>
+{
+    #[inline(always)]
+    fn compare(&self, lhs_opt: &Option<T>, rhs_opt: &Option<T>) -> Ordering {
+        match (lhs_opt, rhs_opt) {
+            (None, None) => Ordering::Equal,
+            (None, Some(_)) => Ordering::Less,
+            (Some(_), None) => Ordering::Greater,
+            (Some(lhs), Some(rhs)) => ReverseComparator.compare(lhs, rhs),
+        }
+    }
+}
+
+impl Comparator<u32> for ReverseNoneIsLowerComparator {
+    #[inline(always)]
+    fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
+        ReverseComparator.compare(lhs, rhs)
+    }
+}
+
+impl Comparator<u64> for ReverseNoneIsLowerComparator {
+    #[inline(always)]
+    fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
+        ReverseComparator.compare(lhs, rhs)
+    }
+}
+
+impl Comparator<f64> for ReverseNoneIsLowerComparator {
+    #[inline(always)]
+    fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
+        ReverseComparator.compare(lhs, rhs)
+    }
+}
+
+impl Comparator<f32> for ReverseNoneIsLowerComparator {
+    #[inline(always)]
+    fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
+        ReverseComparator.compare(lhs, rhs)
+    }
+}
+
+impl Comparator<i64> for ReverseNoneIsLowerComparator {
+    #[inline(always)]
+    fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
+        ReverseComparator.compare(lhs, rhs)
+    }
+}
+
+impl Comparator<String> for ReverseNoneIsLowerComparator {
+    #[inline(always)]
+    fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
+        ReverseComparator.compare(lhs, rhs)
+    }
+}
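Reviewer note: a quick illustration (not part of the diff) of how the comparators above behave on optional sort keys. `Order::Asc` maps to the "reverse, None is lower" comparator in the enum that follows, which is why documents missing a field sort last in the ascending tests of this PR:

```rust
use std::cmp::Ordering;

use tantivy::collector::sort_key::{
    Comparator, NaturalComparator, ReverseNoneIsLowerComparator,
};

fn demo() {
    let natural = NaturalComparator;
    let asc = ReverseNoneIsLowerComparator;
    // Natural order: higher keys compare greater, so a descending top-k keeps 42 first.
    assert_eq!(natural.compare(&Some(42u64), &Some(7u64)), Ordering::Greater);
    // "Ascending" order: smaller keys win...
    assert_eq!(asc.compare(&Some(7u64), &Some(42u64)), Ordering::Greater);
    // ...and None loses to any present value, so documents without the
    // field end up at the tail of the results.
    assert_eq!(asc.compare(&None::<u64>, &Some(42u64)), Ordering::Less);
}
```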
+/// An enum representing the different sort orders.
+#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
+pub enum ComparatorEnum {
+    /// Natural order (See [NaturalComparator])
+    #[default]
+    Natural,
+    /// Reverse order (See [ReverseComparator])
+    Reverse,
+    /// Reverse order, treating None as the lowest value. (See [ReverseNoneIsLowerComparator])
+    ReverseNoneLower,
+}
+
+impl From<Order> for ComparatorEnum {
+    fn from(order: Order) -> Self {
+        match order {
+            Order::Asc => ComparatorEnum::ReverseNoneLower,
+            Order::Desc => ComparatorEnum::Natural,
+        }
+    }
+}
+
+impl<T> Comparator<T> for ComparatorEnum
+where
+    ReverseNoneIsLowerComparator: Comparator<T>,
+    NaturalComparator: Comparator<T>,
+    ReverseComparator: Comparator<T>,
+{
+    #[inline(always)]
+    fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
+        match self {
+            ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs),
+            ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs),
+            ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs),
+        }
+    }
+}
+
+impl<Head, Tail, LeftComparator, RightComparator> Comparator<(Head, Tail)>
+    for (LeftComparator, RightComparator)
+where
+    LeftComparator: Comparator<Head>,
+    RightComparator: Comparator<Tail>,
+{
+    #[inline(always)]
+    fn compare(&self, lhs: &(Head, Tail), rhs: &(Head, Tail)) -> Ordering {
+        self.0
+            .compare(&lhs.0, &rhs.0)
+            .then_with(|| self.1.compare(&lhs.1, &rhs.1))
+    }
+}
+
+impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, (Type2, Type3))>
+    for (Comparator1, Comparator2, Comparator3)
+where
+    Comparator1: Comparator<Type1>,
+    Comparator2: Comparator<Type2>,
+    Comparator3: Comparator<Type3>,
+{
+    #[inline(always)]
+    fn compare(&self, lhs: &(Type1, (Type2, Type3)), rhs: &(Type1, (Type2, Type3))) -> Ordering {
+        self.0
+            .compare(&lhs.0, &rhs.0)
+            .then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
+            .then_with(|| self.2.compare(&lhs.1 .1, &rhs.1 .1))
+    }
+}
+
+impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, Type2, Type3)>
+    for (Comparator1, Comparator2, Comparator3)
+where
+    Comparator1: Comparator<Type1>,
+    Comparator2: Comparator<Type2>,
+    Comparator3: Comparator<Type3>,
+{
+    #[inline(always)]
+    fn compare(&self, lhs: &(Type1, Type2, Type3), rhs: &(Type1, Type2, Type3)) -> Ordering {
+        self.0
+            .compare(&lhs.0, &rhs.0)
+            .then_with(|| self.1.compare(&lhs.1, &rhs.1))
+            .then_with(|| self.2.compare(&lhs.2, &rhs.2))
+    }
+}
+
+impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
+    Comparator<(Type1, (Type2, (Type3, Type4)))>
+    for (Comparator1, Comparator2, Comparator3, Comparator4)
+where
+    Comparator1: Comparator<Type1>,
+    Comparator2: Comparator<Type2>,
+    Comparator3: Comparator<Type3>,
+    Comparator4: Comparator<Type4>,
+{
+    #[inline(always)]
+    fn compare(
+        &self,
+        lhs: &(Type1, (Type2, (Type3, Type4))),
+        rhs: &(Type1, (Type2, (Type3, Type4))),
+    ) -> Ordering {
+        self.0
+            .compare(&lhs.0, &rhs.0)
+            .then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
+            .then_with(|| self.2.compare(&lhs.1 .1 .0, &rhs.1 .1 .0))
+            .then_with(|| self.3.compare(&lhs.1 .1 .1, &rhs.1 .1 .1))
+    }
+}
+
+impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
+    Comparator<(Type1, Type2, Type3, Type4)>
+    for (Comparator1, Comparator2, Comparator3, Comparator4)
+where
+    Comparator1: Comparator<Type1>,
+    Comparator2: Comparator<Type2>,
+    Comparator3: Comparator<Type3>,
+    Comparator4: Comparator<Type4>,
+{
+    #[inline(always)]
+    fn compare(
+        &self,
+        lhs: &(Type1, Type2, Type3, Type4),
+        rhs: &(Type1, Type2, Type3, Type4),
+    ) -> Ordering {
+        self.0
+            .compare(&lhs.0, &rhs.0)
+            .then_with(|| self.1.compare(&lhs.1, &rhs.1))
+            .then_with(|| self.2.compare(&lhs.2, &rhs.2))
+            .then_with(|| self.3.compare(&lhs.3, &rhs.3))
+    }
+}
+
+impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, ComparatorEnum)
+where
+    TSortKeyComputer: SortKeyComputer,
+    ComparatorEnum: Comparator<TSortKeyComputer::SortKey>,
+    ComparatorEnum: Comparator<
+        <<TSortKeyComputer as SortKeyComputer>::Child as SegmentSortKeyComputer>::SegmentSortKey,
+    >,
+{
+    type SortKey = TSortKeyComputer::SortKey;
+
+    type Child = SegmentSortKeyComputerWithComparator<TSortKeyComputer::Child, ComparatorEnum>;
+
+    type Comparator = ComparatorEnum;
+
+    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
+        self.0.check_schema(schema)
+    }
+
+    fn requires_scoring(&self) -> bool {
+        self.0.requires_scoring()
+    }
+
+    fn comparator(&self) -> Self::Comparator {
+        self.1
+    }
+
+    fn segment_sort_key_computer(
+        &self,
+        segment_reader: &crate::SegmentReader,
+    ) -> crate::Result<Self::Child> {
+        let child = self.0.segment_sort_key_computer(segment_reader)?;
+        Ok(SegmentSortKeyComputerWithComparator {
+            segment_sort_key_computer: child,
+            comparator: self.comparator(),
+        })
+    }
+}
+
+impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, Order)
+where
+    TSortKeyComputer: SortKeyComputer,
+    ComparatorEnum: Comparator<TSortKeyComputer::SortKey>,
+    ComparatorEnum: Comparator<
+        <<TSortKeyComputer as SortKeyComputer>::Child as SegmentSortKeyComputer>::SegmentSortKey,
+    >,
+{
+    type SortKey = TSortKeyComputer::SortKey;
+
+    type Child = SegmentSortKeyComputerWithComparator<TSortKeyComputer::Child, ComparatorEnum>;
+
+    type Comparator = ComparatorEnum;
+
+    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
+        self.0.check_schema(schema)
+    }
+
+    fn requires_scoring(&self) -> bool {
+        self.0.requires_scoring()
+    }
+
+    fn comparator(&self) -> Self::Comparator {
+        self.1.into()
+    }
+
+    fn segment_sort_key_computer(
+        &self,
+        segment_reader: &crate::SegmentReader,
+    ) -> crate::Result<Self::Child> {
+        let child = self.0.segment_sort_key_computer(segment_reader)?;
+        Ok(SegmentSortKeyComputerWithComparator {
+            segment_sort_key_computer: child,
+            comparator: self.comparator(),
+        })
+    }
+}
+
+/// A segment sort key computer with a custom ordering.
+pub struct SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator> {
+    segment_sort_key_computer: TSegmentSortKeyComputer,
+    comparator: TComparator,
+}
+
+impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComputer
+    for SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator>
+where
+    TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
+    TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send,
+    TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
+{
+    type SortKey = TSegmentSortKeyComputer::SortKey;
+    type SegmentSortKey = TSegmentSortKey;
+
+    fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
+        self.segment_sort_key_computer.segment_sort_key(doc, score)
+    }
+
+    #[inline(always)]
+    fn compare_segment_sort_key(
+        &self,
+        left: &Self::SegmentSortKey,
+        right: &Self::SegmentSortKey,
+    ) -> Ordering {
+        self.comparator.compare(left, right)
+    }
+
+    fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
+        self.segment_sort_key_computer
+            .convert_segment_sort_key(sort_key)
+    }
+}
diff --git a/src/collector/sort_key/sort_by_score.rs b/src/collector/sort_key/sort_by_score.rs
new file mode 100644
index 000000000..df8b0dd75
--- /dev/null
+++ b/src/collector/sort_key/sort_by_score.rs
@@ -0,0 +1,77 @@
+use crate::collector::sort_key::NaturalComparator;
+use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
+use crate::{DocAddress, DocId, Score};
+
+/// Sort by similarity score.
+#[derive(Clone, Debug, Copy)]
+pub struct SortBySimilarityScore;
+
+impl SortKeyComputer for SortBySimilarityScore {
+    type SortKey = Score;
+
+    type Child = SortBySimilarityScore;
+
+    type Comparator = NaturalComparator;
+
+    fn requires_scoring(&self) -> bool {
+        true
+    }
+
+    fn segment_sort_key_computer(
+        &self,
+        _segment_reader: &crate::SegmentReader,
+    ) -> crate::Result<Self::Child> {
+        Ok(SortBySimilarityScore)
+    }
+
+    // Sorting by score is special in that it allows for the Block-Wand optimization.
+    fn collect_segment_top_k(
+        &self,
+        k: usize,
+        weight: &dyn crate::query::Weight,
+        reader: &crate::SegmentReader,
+        segment_ord: u32,
+    ) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
+        let mut top_n: TopNComputer<Score, DocId, NaturalComparator> =
+            TopNComputer::new_with_comparator(k, self.comparator());
+
+        if let Some(alive_bitset) = reader.alive_bitset() {
+            let mut threshold = Score::MIN;
+            top_n.threshold = Some(threshold);
+            weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
+                if alive_bitset.is_deleted(doc) {
+                    return threshold;
+                }
+                top_n.push(score, doc);
+                threshold = top_n.threshold.unwrap_or(Score::MIN);
+                threshold
+            })?;
+        } else {
+            weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
+                top_n.push(score, doc);
+                top_n.threshold.unwrap_or(Score::MIN)
+            })?;
+        }
+
+        Ok(top_n
+            .into_vec()
+            .into_iter()
+            .map(|cid| (cid.sort_key, DocAddress::new(segment_ord, cid.doc)))
+            .collect())
+    }
+}
+
+impl SegmentSortKeyComputer for SortBySimilarityScore {
+    type SortKey = Score;
+
+    type SegmentSortKey = Score;
+
+    #[inline(always)]
+    fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {
+        score
+    }
+
+    fn convert_segment_sort_key(&self, score: Score) -> Score {
+        score
+    }
+}
diff --git a/src/collector/sort_key/sort_by_static_fast_value.rs b/src/collector/sort_key/sort_by_static_fast_value.rs
new file mode 100644
index 000000000..b38b8b034
--- /dev/null
+++ b/src/collector/sort_key/sort_by_static_fast_value.rs
@@ -0,0 +1,98 @@
+use std::marker::PhantomData;
+
+use columnar::Column;
+
+use crate::collector::sort_key::NaturalComparator;
+use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
+use crate::fastfield::{FastFieldNotAvailableError, FastValue};
+use crate::{DocId, Score, SegmentReader};
+
+/// Sorts by a fast value (u64, i64, f64, bool).
+///
+/// The field must appear explicitly in the schema, with the right type, and declared as
+/// a fast field.
+///
+/// If the field is multivalued, only the first value is considered.
+///
+/// Documents that do not have this value are still considered.
+/// Their sort key will simply be `None`.
+#[derive(Debug, Clone)]
+pub struct SortByStaticFastValue<T> {
+    field: String,
+    typ: PhantomData<T>,
+}
+
+impl<T: FastValue> SortByStaticFastValue<T> {
+    /// Creates a new `SortByStaticFastValue` instance for the given field.
+    pub fn for_field(column_name: impl ToString) -> SortByStaticFastValue<T> {
+        Self {
+            field: column_name.to_string(),
+            typ: PhantomData,
+        }
+    }
+}
+
+impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
+    type Child = SortByFastValueSegmentSortKeyComputer<T>;
+
+    type SortKey = Option<T>;
+
+    type Comparator = NaturalComparator;
+
+    fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> {
+        // At the segment sort key computer level, we rely on the u64 representation.
+        // The mapping is monotonic, so it is sufficient to compute our top-K docs.
+        let field = schema.get_field(&self.field)?;
+        let field_entry = schema.get_field_entry(field);
+        if !field_entry.is_fast() {
+            return Err(crate::TantivyError::SchemaError(format!(
+                "Field `{}` is not a fast field.",
+                self.field,
+            )));
+        }
+        let schema_type = field_entry.field_type().value_type();
+        if schema_type != T::to_type() {
+            return Err(crate::TantivyError::SchemaError(format!(
+                "Field `{}` is of type {schema_type:?}, not of the type {:?}.",
+                &self.field,
+                T::to_type()
+            )));
+        }
+        Ok(())
+    }
+
+    fn segment_sort_key_computer(
+        &self,
+        segment_reader: &SegmentReader,
+    ) -> crate::Result<Self::Child> {
+        let sort_column_opt = segment_reader.fast_fields().u64_lenient(&self.field)?;
+        let (sort_column, _sort_column_type) =
+            sort_column_opt.ok_or_else(|| FastFieldNotAvailableError {
+                field_name: self.field.clone(),
+            })?;
+        Ok(SortByFastValueSegmentSortKeyComputer {
+            sort_column,
+            typ: PhantomData,
+        })
+    }
+}
+
+pub struct SortByFastValueSegmentSortKeyComputer<T> {
+    sort_column: Column<u64>,
+    typ: PhantomData<T>,
+}
+
+impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
+    type SortKey = Option<T>;
+
+    type SegmentSortKey = Option<u64>;
+
+    #[inline(always)]
+    fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {
+        self.sort_column.first(doc)
+    }
+
+    fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
+        sort_key.map(T::from_u64)
+    }
+}
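Reviewer note: a usage sketch for the new fast-value sort key, mirroring `test_order_by_f64` above (the `altitude` field comes from the test schema). Each hit carries its sort key, an `Option<f64>` that is `None` when the document has no value for the field:

```rust
use tantivy::collector::sort_key::SortByStaticFastValue;
use tantivy::collector::TopDocs;
use tantivy::query::AllQuery;
use tantivy::{Index, Order};

fn lowest_altitudes(index: &Index) -> tantivy::Result<()> {
    let searcher = index.reader()?.searcher();
    // Top 3 documents by the "altitude" fast field, ascending;
    // documents missing the field would sort last.
    let top_collector = TopDocs::with_limit(3)
        .order_by((SortByStaticFastValue::<f64>::for_field("altitude"), Order::Asc));
    for (altitude, doc_address) in searcher.search(&AllQuery, &top_collector)? {
        println!("{altitude:?} => {doc_address:?}");
    }
    Ok(())
}
```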
diff --git a/src/collector/sort_key/sort_by_string.rs b/src/collector/sort_key/sort_by_string.rs
new file mode 100644
index 000000000..41ef22e9b
--- /dev/null
+++ b/src/collector/sort_key/sort_by_string.rs
@@ -0,0 +1,72 @@
+use columnar::StrColumn;
+
+use crate::collector::sort_key::NaturalComparator;
+use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
+use crate::termdict::TermOrdinal;
+use crate::{DocId, Score};
+
+/// Sort by the first value of a string column.
+///
+/// The string can be dynamic (coming from a json field)
+/// or static (being specifically defined in the configuration).
+///
+/// If the field is multivalued, only the first value is considered.
+///
+/// Documents that do not have this value are still considered.
+/// Their sort key will simply be `None`.
+#[derive(Debug, Clone)]
+pub struct SortByString {
+    column_name: String,
+}
+
+impl SortByString {
+    /// Creates a new `SortByString` sort key computer for the given column.
+    pub fn for_field(column_name: impl ToString) -> Self {
+        SortByString {
+            column_name: column_name.to_string(),
+        }
+    }
+}
+
+impl SortKeyComputer for SortByString {
+    type SortKey = Option<String>;
+
+    type Child = ByStringColumnSegmentSortKeyComputer;
+
+    type Comparator = NaturalComparator;
+
+    fn segment_sort_key_computer(
+        &self,
+        segment_reader: &crate::SegmentReader,
+    ) -> crate::Result<Self::Child> {
+        let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?;
+        Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt })
+    }
+}
+
+pub struct ByStringColumnSegmentSortKeyComputer {
+    str_column_opt: Option<StrColumn>,
+}
+
+impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
+    type SortKey = Option<String>;
+
+    type SegmentSortKey = Option<TermOrdinal>;
+
+    #[inline(always)]
+    fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
+        let str_column = self.str_column_opt.as_ref()?;
+        str_column.ords().first(doc)
+    }
+
+    fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
+        let term_ord = term_ord_opt?;
+        let str_column = self.str_column_opt.as_ref()?;
+        let mut bytes = Vec::new();
+        str_column
+            .dictionary()
+            .ord_to_term(term_ord, &mut bytes)
+            .ok()?;
+        String::try_from(bytes).ok()
+    }
+}
diff --git a/src/collector/sort_key/sort_key_computer.rs b/src/collector/sort_key/sort_key_computer.rs
new file mode 100644
index 000000000..d56fa7cd0
--- /dev/null
+++ b/src/collector/sort_key/sort_key_computer.rs
@@ -0,0 +1,631 @@
+use std::cmp::Ordering;
+
+use crate::collector::sort_key::{Comparator, NaturalComparator};
+use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector;
+use crate::collector::{default_collect_segment_impl, SegmentCollector as _, TopNComputer};
+use crate::schema::Schema;
+use crate::{DocAddress, DocId, Result, Score, SegmentReader};
+
+/// A `SegmentSortKeyComputer` makes it possible to compute the sort key
+/// for a given document belonging to a specific segment.
+///
+/// It is the segment local version of the [`SortKeyComputer`].
+pub trait SegmentSortKeyComputer: 'static {
+    /// The final sort key being emitted.
+    type SortKey: 'static + PartialOrd + Send + Sync + Clone;
+
+    /// Sort key used at the segment level by the `SegmentSortKeyComputer`.
+    ///
+    /// It is typically small like a `u64`, and is meant to be converted
+    /// to the final sort key at the end of the collection of the segment.
+    type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync;
+
+    /// Computes the sort key for the given document and score.
+    fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;
+
+    /// Computes the sort key and pushes the document in a TopN Computer.
+    ///
+    /// When using a tuple as the sorting key, the sort key is evaluated in a lazy manner.
+    #[inline(always)]
+    fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
+        &mut self,
+        doc: DocId,
+        score: Score,
+        top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
+    ) {
+        let sort_key = self.segment_sort_key(doc, score);
+        top_n_computer.push(sort_key, doc);
+    }
+
+    /// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
+    /// its ordering.
+    ///
+    /// This method must be consistent with the `SortKey` ordering.
+    #[inline(always)]
+    fn compare_segment_sort_key(
+        &self,
+        left: &Self::SegmentSortKey,
+        right: &Self::SegmentSortKey,
+    ) -> Ordering {
+        NaturalComparator.compare(left, right)
+    }
+
+    /// Implementing this method makes it possible to avoid computing
+    /// a sort key entirely if a partial computation shows that it cannot
+    /// pass a given threshold.
+    ///
+    /// This is currently used for lexicographic sorting.
+    fn accept_sort_key_lazy(
+        &mut self,
+        doc_id: DocId,
+        score: Score,
+        threshold: &Self::SegmentSortKey,
+    ) -> Option<(Ordering, Self::SegmentSortKey)> {
+        let sort_key = self.segment_sort_key(doc_id, score);
+        let cmp = self.compare_segment_sort_key(&sort_key, threshold);
+        if cmp == Ordering::Less {
+            None
+        } else {
+            Some((cmp, sort_key))
+        }
+    }
+
+    /// Convert a segment level sort key into the global sort key.
+    fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey;
+}
+
+/// `SortKeyComputer` defines the sort key to be used by a TopK collector.
+///
+/// The `SortKeyComputer` does little of the computation itself.
+/// Instead, it helps construct `Self::Child` instances that will compute
+/// the sort key at a segment scale.
+pub trait SortKeyComputer: Sync {
+    /// The sort key type.
+    type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug;
+    /// Type of the associated [`SegmentSortKeyComputer`].
+    type Child: SegmentSortKeyComputer<SortKey = Self::SortKey>;
+    /// Comparator type.
+    type Comparator: Comparator<Self::SortKey>
+        + Comparator<<Self::Child as SegmentSortKeyComputer>::SegmentSortKey>
+        + 'static;
+
+    /// Checks whether the schema is compatible with the sort key computer.
+    fn check_schema(&self, _schema: &Schema) -> crate::Result<()> {
+        Ok(())
+    }
+
+    /// Returns the sort key comparator.
+    fn comparator(&self) -> Self::Comparator {
+        Self::Comparator::default()
+    }
+
+    /// Indicates whether the sort key actually uses the similarity score (by default BM25).
+    /// If set to false, the similarity score might not be computed (as an optimization),
+    /// and the score fed into the segment sort key computer could take any value.
+    fn requires_scoring(&self) -> bool {
+        false
+    }
+
+    /// Sorting by score has an overriding implementation for BM25 scores, using Block-WAND.
+    fn collect_segment_top_k(
+        &self,
+        k: usize,
+        weight: &dyn crate::query::Weight,
+        reader: &crate::SegmentReader,
+        segment_ord: u32,
+    ) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
+        let with_scoring = self.requires_scoring();
+        let segment_sort_key_computer = self.segment_sort_key_computer(reader)?;
+        let topn_computer = TopNComputer::new_with_comparator(k, self.comparator());
+        let mut segment_top_key_collector = TopBySortKeySegmentCollector {
+            topn_computer,
+            segment_ord,
+            segment_sort_key_computer,
+        };
+        default_collect_segment_impl(&mut segment_top_key_collector, weight, reader, with_scoring)?;
+        Ok(segment_top_key_collector.harvest())
+    }
+
+    /// Builds a child sort key computer for a specific segment.
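+    ///
+    /// Plain closures already implement `SortKeyComputer` (see the blanket
+    /// implementation further down), so a hand-rolled computer can be sketched
+    /// as below; the `"rank"` fast field is an assumption made for the example:
+    ///
+    /// ```ignore
+    /// let by_rank = |segment_reader: &SegmentReader| {
+    ///     let rank_col = segment_reader
+    ///         .fast_fields()
+    ///         .u64("rank")
+    ///         .unwrap()
+    ///         .first_or_default_col(0);
+    ///     move |doc: DocId| rank_col.get_val(doc)
+    /// };
+    /// ```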
+    fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
+}
+
+impl<HeadSortKeyComputer, TailSortKeyComputer> SortKeyComputer
+    for (HeadSortKeyComputer, TailSortKeyComputer)
+where
+    HeadSortKeyComputer: SortKeyComputer,
+    TailSortKeyComputer: SortKeyComputer,
+{
+    type SortKey = (
+        <HeadSortKeyComputer as SortKeyComputer>::SortKey,
+        <TailSortKeyComputer as SortKeyComputer>::SortKey,
+    );
+    type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);
+
+    type Comparator = (
+        HeadSortKeyComputer::Comparator,
+        TailSortKeyComputer::Comparator,
+    );
+
+    fn comparator(&self) -> Self::Comparator {
+        (self.0.comparator(), self.1.comparator())
+    }
+
+    fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
+        Ok((
+            self.0.segment_sort_key_computer(segment_reader)?,
+            self.1.segment_sort_key_computer(segment_reader)?,
+        ))
+    }
+
+    /// Checks whether the schema is compatible with the sort key computer.
+    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
+        self.0.check_schema(schema)?;
+        self.1.check_schema(schema)?;
+        Ok(())
+    }
+
+    /// Indicates whether the sort key actually uses the similarity score (by default BM25).
+    /// If set to false, the similarity score might not be computed (as an optimization),
+    /// and the score fed into the segment sort key computer could take any value.
+    fn requires_scoring(&self) -> bool {
+        self.0.requires_scoring() || self.1.requires_scoring()
+    }
+}
+
+impl<HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer> SegmentSortKeyComputer
+    for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer)
+where
+    HeadSegmentSortKeyComputer: SegmentSortKeyComputer,
+    TailSegmentSortKeyComputer: SegmentSortKeyComputer,
+{
+    type SortKey = (
+        HeadSegmentSortKeyComputer::SortKey,
+        TailSegmentSortKeyComputer::SortKey,
+    );
+    type SegmentSortKey = (
+        HeadSegmentSortKeyComputer::SegmentSortKey,
+        TailSegmentSortKeyComputer::SegmentSortKey,
+    );
+
+    /// A `SegmentSortKeyComputer` maps documents to a `SegmentSortKey`, but it can
+    /// also decide on its ordering.
+    ///
+    /// By default, it uses the natural ordering.
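+    ///
+    /// For tuples this yields a lexicographic comparison: the head keys are
+    /// compared first and the tail keys only break ties, so `(1, 9)` ranks
+    /// below `(2, 0)` under the natural ordering.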
+    #[inline]
+    fn compare_segment_sort_key(
+        &self,
+        left: &Self::SegmentSortKey,
+        right: &Self::SegmentSortKey,
+    ) -> Ordering {
+        self.0
+            .compare_segment_sort_key(&left.0, &right.0)
+            .then_with(|| self.1.compare_segment_sort_key(&left.1, &right.1))
+    }
+
+    #[inline(always)]
+    fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
+        &mut self,
+        doc: DocId,
+        score: Score,
+        top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
+    ) {
+        let sort_key: Self::SegmentSortKey;
+        if let Some(threshold) = &top_n_computer.threshold {
+            if let Some((_cmp, lazy_sort_key)) = self.accept_sort_key_lazy(doc, score, threshold) {
+                sort_key = lazy_sort_key;
+            } else {
+                return;
+            }
+        } else {
+            sort_key = self.segment_sort_key(doc, score);
+        };
+        top_n_computer.append_doc(doc, sort_key);
+    }
+
+    #[inline(always)]
+    fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
+        let head_sort_key = self.0.segment_sort_key(doc, score);
+        let tail_sort_key = self.1.segment_sort_key(doc, score);
+        (head_sort_key, tail_sort_key)
+    }
+
+    fn accept_sort_key_lazy(
+        &mut self,
+        doc_id: DocId,
+        score: Score,
+        threshold: &Self::SegmentSortKey,
+    ) -> Option<(Ordering, Self::SegmentSortKey)> {
+        let (head_threshold, tail_threshold) = threshold;
+        let (head_cmp, head_sort_key) =
+            self.0.accept_sort_key_lazy(doc_id, score, head_threshold)?;
+        if head_cmp == Ordering::Equal {
+            let (tail_cmp, tail_sort_key) =
+                self.1.accept_sort_key_lazy(doc_id, score, tail_threshold)?;
+            Some((tail_cmp, (head_sort_key, tail_sort_key)))
+        } else {
+            let tail_sort_key = self.1.segment_sort_key(doc_id, score);
+            Some((head_cmp, (head_sort_key, tail_sort_key)))
+        }
+    }
+
+    fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
+        let (head_sort_key, tail_sort_key) = sort_key;
+        (
+            self.0.convert_segment_sort_key(head_sort_key),
+            self.1.convert_segment_sort_key(tail_sort_key),
+        )
+    }
+}
+
+/// This struct is used as an adapter taking a sort key computer and mapping its
+/// sort key to a new one.
+pub struct MappedSegmentSortKeyComputer<T, PreviousSortKey, NewSortKey> {
+    sort_key_computer: T,
+    map: fn(PreviousSortKey) -> NewSortKey,
+}
+
+impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
+    for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
+where
+    T: SegmentSortKeyComputer<SortKey = PreviousScore>,
+    PreviousScore: 'static + Clone + Send + Sync + PartialOrd,
+    NewScore: 'static + Clone + Send + Sync + PartialOrd,
+{
+    type SortKey = NewScore;
+    type SegmentSortKey = T::SegmentSortKey;
+
+    fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
+        self.sort_key_computer.segment_sort_key(doc, score)
+    }
+
+    fn accept_sort_key_lazy(
+        &mut self,
+        doc_id: DocId,
+        score: Score,
+        threshold: &Self::SegmentSortKey,
+    ) -> Option<(Ordering, Self::SegmentSortKey)> {
+        self.sort_key_computer
+            .accept_sort_key_lazy(doc_id, score, threshold)
+    }
+
+    #[inline(always)]
+    fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
+        &mut self,
+        doc: DocId,
+        score: Score,
+        top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
+    ) {
+        self.sort_key_computer
+            .compute_sort_key_and_collect(doc, score, top_n_computer);
+    }
+
+    fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey {
+        (self.map)(
+            self.sort_key_computer
+                .convert_segment_sort_key(segment_sort_key),
+        )
+    }
+}
+
+// We then re-use our (head, tail) implementation and our mapper by viewing any tuple (a, b, c,
+// ...)
as the chain (a, (b, (c, ...))) + +impl SortKeyComputer + for (SortKeyComputer1, SortKeyComputer2, SortKeyComputer3) +where + SortKeyComputer1: SortKeyComputer, + SortKeyComputer2: SortKeyComputer, + SortKeyComputer3: SortKeyComputer, +{ + type SortKey = ( + SortKeyComputer1::SortKey, + SortKeyComputer2::SortKey, + SortKeyComputer3::SortKey, + ); + type Child = MappedSegmentSortKeyComputer< + <(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child, + ( + SortKeyComputer1::SortKey, + (SortKeyComputer2::SortKey, SortKeyComputer3::SortKey), + ), + Self::SortKey, + >; + + type Comparator = ( + SortKeyComputer1::Comparator, + SortKeyComputer2::Comparator, + SortKeyComputer3::Comparator, + ); + + fn comparator(&self) -> Self::Comparator { + ( + self.0.comparator(), + self.1.comparator(), + self.2.comparator(), + ) + } + + fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { + let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?; + let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?; + let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?; + let map = |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3); + Ok(MappedSegmentSortKeyComputer { + sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)), + map, + }) + } + + fn check_schema(&self, schema: &Schema) -> crate::Result<()> { + self.0.check_schema(schema)?; + self.1.check_schema(schema)?; + self.2.check_schema(schema)?; + Ok(()) + } + + fn requires_scoring(&self) -> bool { + self.0.requires_scoring() || self.1.requires_scoring() || self.2.requires_scoring() + } +} + +impl SortKeyComputer + for ( + SortKeyComputer1, + SortKeyComputer2, + SortKeyComputer3, + SortKeyComputer4, + ) +where + SortKeyComputer1: SortKeyComputer, + SortKeyComputer2: SortKeyComputer, + SortKeyComputer3: SortKeyComputer, + SortKeyComputer4: SortKeyComputer, +{ + type Child = MappedSegmentSortKeyComputer< + <( + SortKeyComputer1, + (SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)), + ) as SortKeyComputer>::Child, + ( + SortKeyComputer1::SortKey, + ( + SortKeyComputer2::SortKey, + (SortKeyComputer3::SortKey, SortKeyComputer4::SortKey), + ), + ), + Self::SortKey, + >; + type SortKey = ( + SortKeyComputer1::SortKey, + SortKeyComputer2::SortKey, + SortKeyComputer3::SortKey, + SortKeyComputer4::SortKey, + ); + type Comparator = ( + SortKeyComputer1::Comparator, + SortKeyComputer2::Comparator, + SortKeyComputer3::Comparator, + SortKeyComputer4::Comparator, + ); + + fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { + let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?; + let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?; + let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?; + let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?; + Ok(MappedSegmentSortKeyComputer { + sort_key_computer: ( + sort_key_computer1, + (sort_key_computer2, (sort_key_computer3, sort_key_computer4)), + ), + map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| { + (sort_key1, sort_key2, sort_key3, sort_key4) + }, + }) + } + + fn check_schema(&self, schema: &Schema) -> crate::Result<()> { + self.0.check_schema(schema)?; + self.1.check_schema(schema)?; + self.2.check_schema(schema)?; + self.3.check_schema(schema)?; + Ok(()) + } + + fn requires_scoring(&self) -> bool { + self.0.requires_scoring() + || 
self.1.requires_scoring() + || self.2.requires_scoring() + || self.3.requires_scoring() + } +} + +impl SortKeyComputer for F +where + F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF, + SegmentF: 'static + FnMut(DocId) -> TSortKey, + TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug, +{ + type SortKey = TSortKey; + type Child = SegmentF; + type Comparator = NaturalComparator; + + fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { + Ok((self)(segment_reader)) + } +} + +impl SegmentSortKeyComputer for F +where + F: 'static + FnMut(DocId) -> TSortKey, + TSortKey: 'static + PartialOrd + Clone + Send + Sync, +{ + type SortKey = TSortKey; + type SegmentSortKey = TSortKey; + + fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey { + (self)(doc) + } + + /// Convert a segment level score into the global level score. + fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { + sort_key + } +} + +#[cfg(test)] +mod tests { + use std::cmp::Ordering; + use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering}; + use std::sync::Arc; + + use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; + use crate::schema::Schema; + use crate::{DocId, Index, Order, SegmentReader}; + + fn build_test_index() -> Index { + let schema = Schema::builder().build(); + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests().unwrap(); + index_writer + .add_document(crate::TantivyDocument::default()) + .unwrap(); + index_writer.commit().unwrap(); + index + } + + #[test] + fn test_lazy_score_computer() { + let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32; + let call_count = Arc::new(AtomicUsize::new(0)); + let call_count_clone = call_count.clone(); + let score_computer_secondary = move |_segment_reader: &SegmentReader| { + let call_count_new_clone = call_count_clone.clone(); + move |_doc: DocId| { + call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst); + "b" + } + }; + let lazy_score_computer = (score_computer_primary, score_computer_secondary); + let index = build_test_index(); + let searcher = index.reader().unwrap().searcher(); + let mut segment_sort_key_computer = lazy_score_computer + .segment_sort_key_computer(searcher.segment_reader(0)) + .unwrap(); + let expected_sort_key = (200, "b"); + { + let sort_key_opt = + segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, "a")); + assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key))); + assert_eq!(call_count.load(AtomicOrdering::SeqCst), 1); + } + { + let sort_key_opt = + segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, "c")); + assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key))); + assert_eq!(call_count.load(AtomicOrdering::SeqCst), 2); + } + { + let sort_key_opt = + segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, "a")); + assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key))); + assert_eq!(call_count.load(AtomicOrdering::SeqCst), 3); + } + { + let sort_key_opt = + segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, "c")); + assert!(sort_key_opt.is_none()); + assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4); + } + { + let sort_key_opt = + segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, "a")); + assert_eq!(sort_key_opt, None); + assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4); + } + { + let sort_key_opt = + 
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, "c")); + assert_eq!(sort_key_opt, None); + assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4); + } + { + let sort_key_opt = + segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &expected_sort_key); + assert_eq!(sort_key_opt, Some((Ordering::Equal, expected_sort_key))); + assert_eq!(call_count.load(AtomicOrdering::SeqCst), 5); + } + } + + #[test] + fn test_lazy_score_computer_dynamic_ordering() { + let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32; + let call_count = Arc::new(AtomicUsize::new(0)); + let call_count_clone = call_count.clone(); + let score_computer_secondary = move |_segment_reader: &SegmentReader| { + let call_count_new_clone = call_count_clone.clone(); + move |_doc: DocId| { + call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst); + 2u32 + } + }; + let lazy_score_computer = ( + (score_computer_primary, Order::Desc), + (score_computer_secondary, Order::Asc), + ); + let index = build_test_index(); + let searcher = index.reader().unwrap().searcher(); + let mut segment_sort_key_computer = lazy_score_computer + .segment_sort_key_computer(searcher.segment_reader(0)) + .unwrap(); + let expected_sort_key = (200, 2u32); + + { + let sort_key_opt = + segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, 1u32)); + assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key))); + assert_eq!(call_count.load(AtomicOrdering::SeqCst), 1); + } + { + let sort_key_opt = + segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, 3u32)); + assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key))); + assert_eq!(call_count.load(AtomicOrdering::SeqCst), 2); + } + { + let sort_key_opt = + segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, 1u32)); + assert!(sort_key_opt.is_none()); + assert_eq!(call_count.load(AtomicOrdering::SeqCst), 3); + } + { + let sort_key_opt = + segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, 3u32)); + assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key))); + assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4); + } + { + let sort_key_opt = + segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, 1u32)); + assert_eq!(sort_key_opt, None); + assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4); + } + { + let sort_key_opt = + segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, 3u32)); + assert_eq!(sort_key_opt, None); + assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4); + } + { + let sort_key_opt = + segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &expected_sort_key); + assert_eq!(sort_key_opt, Some((Ordering::Equal, expected_sort_key))); + assert_eq!(call_count.load(AtomicOrdering::SeqCst), 5); + } + assert_eq!( + segment_sort_key_computer.convert_segment_sort_key(expected_sort_key), + (200u32, 2u32) + ); + } +} diff --git a/src/collector/sort_key_top_collector.rs b/src/collector/sort_key_top_collector.rs new file mode 100644 index 000000000..3ca27fc75 --- /dev/null +++ b/src/collector/sort_key_top_collector.rs @@ -0,0 +1,193 @@ +use std::ops::Range; + +use crate::collector::sort_key::{Comparator, SegmentSortKeyComputer, SortKeyComputer}; +use crate::collector::{Collector, SegmentCollector, TopNComputer}; +use crate::query::Weight; +use crate::schema::Schema; +use crate::{DocAddress, DocId, Result, Score, SegmentReader}; + +pub(crate) struct TopBySortKeyCollector { + sort_key_computer: TSortKeyComputer, + 
doc_range: Range, +} + +impl TopBySortKeyCollector { + pub fn new(sort_key_computer: TSortKeyComputer, doc_range: Range) -> Self { + TopBySortKeyCollector { + sort_key_computer, + doc_range, + } + } +} + +impl Collector for TopBySortKeyCollector +where TSortKeyComputer: SortKeyComputer + Send + Sync + 'static +{ + type Fruit = Vec<(TSortKeyComputer::SortKey, DocAddress)>; + + type Child = + TopBySortKeySegmentCollector; + + fn check_schema(&self, schema: &Schema) -> crate::Result<()> { + self.sort_key_computer.check_schema(schema) + } + + fn for_segment(&self, segment_ord: u32, segment_reader: &SegmentReader) -> Result { + let segment_sort_key_computer = self + .sort_key_computer + .segment_sort_key_computer(segment_reader)?; + let topn_computer = TopNComputer::new_with_comparator( + self.doc_range.end, + self.sort_key_computer.comparator(), + ); + Ok(TopBySortKeySegmentCollector { + topn_computer, + segment_ord, + segment_sort_key_computer, + }) + } + + fn requires_scoring(&self) -> bool { + self.sort_key_computer.requires_scoring() + } + + fn merge_fruits(&self, segment_fruits: Vec) -> Result { + Ok(merge_top_k( + segment_fruits.into_iter().flatten(), + self.doc_range.clone(), + self.sort_key_computer.comparator(), + )) + } + + fn collect_segment( + &self, + weight: &dyn Weight, + segment_ord: u32, + reader: &SegmentReader, + ) -> crate::Result> { + let k = self.doc_range.end; + let docs = self + .sort_key_computer + .collect_segment_top_k(k, weight, reader, segment_ord)?; + Ok(docs) + } +} + +fn merge_top_k>( + sort_key_docs: impl Iterator, + doc_range: Range, + comparator: C, +) -> Vec<(TSortKey, D)> { + if doc_range.is_empty() { + return Vec::new(); + } + let mut top_collector: TopNComputer = + TopNComputer::new_with_comparator(doc_range.end, comparator); + for (sort_key, doc) in sort_key_docs { + top_collector.push(sort_key, doc); + } + top_collector + .into_sorted_vec() + .into_iter() + .skip(doc_range.start) + .map(|cdoc| (cdoc.sort_key, cdoc.doc)) + .collect() +} + +pub struct TopBySortKeySegmentCollector +where + TSegmentSortKeyComputer: SegmentSortKeyComputer, + C: Comparator, +{ + pub(crate) topn_computer: TopNComputer, + pub(crate) segment_ord: u32, + pub(crate) segment_sort_key_computer: TSegmentSortKeyComputer, +} + +impl SegmentCollector + for TopBySortKeySegmentCollector +where + TSegmentSortKeyComputer: 'static + SegmentSortKeyComputer, + C: Comparator + 'static, +{ + type Fruit = Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)>; + + fn collect(&mut self, doc: DocId, score: Score) { + self.segment_sort_key_computer.compute_sort_key_and_collect( + doc, + score, + &mut self.topn_computer, + ); + } + + fn harvest(self) -> Self::Fruit { + let segment_ord = self.segment_ord; + let segment_hits: Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)> = self + .topn_computer + .into_vec() + .into_iter() + .map(|comparable_doc| { + let sort_key = self + .segment_sort_key_computer + .convert_segment_sort_key(comparable_doc.sort_key); + ( + sort_key, + DocAddress { + segment_ord, + doc_id: comparable_doc.doc, + }, + ) + }) + .collect(); + segment_hits + } +} + +#[cfg(test)] +mod tests { + use std::ops::Range; + + use rand; + use rand::seq::SliceRandom as _; + + use super::merge_top_k; + use crate::collector::sort_key::ComparatorEnum; + use crate::Order; + + fn test_merge_top_k_aux( + order: Order, + doc_range: Range, + expected: &[(crate::Score, usize)], + ) { + let mut vals: Vec<(crate::Score, usize)> = (0..10).map(|val| (val as f32, val)).collect(); + vals.shuffle(&mut 
rand::thread_rng()); + let vals_merged = merge_top_k(vals.into_iter(), doc_range, ComparatorEnum::from(order)); + assert_eq!(&vals_merged, expected); + } + + #[test] + fn test_merge_top_k() { + test_merge_top_k_aux(Order::Asc, 0..0, &[]); + test_merge_top_k_aux(Order::Asc, 3..3, &[]); + test_merge_top_k_aux(Order::Asc, 0..3, &[(0.0f32, 0), (1.0f32, 1), (2.0f32, 2)]); + test_merge_top_k_aux( + Order::Asc, + 0..11, + &[ + (0.0f32, 0), + (1.0f32, 1), + (2.0f32, 2), + (3.0f32, 3), + (4.0f32, 4), + (5.0f32, 5), + (6.0f32, 6), + (7.0f32, 7), + (8.0f32, 8), + (9.0f32, 9), + ], + ); + test_merge_top_k_aux(Order::Asc, 1..3, &[(1.0f32, 1), (2.0f32, 2)]); + test_merge_top_k_aux(Order::Desc, 0..2, &[(9.0f32, 9), (8.0f32, 8)]); + test_merge_top_k_aux(Order::Desc, 2..4, &[(7.0f32, 7), (6.0f32, 6)]); + } +} diff --git a/src/collector/tests.rs b/src/collector/tests.rs index 7af7c6d8c..61b6a595b 100644 --- a/src/collector/tests.rs +++ b/src/collector/tests.rs @@ -40,7 +40,7 @@ pub fn test_filter_collector() -> crate::Result<()> { let filter_some_collector = FilterCollector::new( "price".to_string(), &|value: u64| value > 20_120u64, - TopDocs::with_limit(2), + TopDocs::with_limit(2).order_by_score(), ); let top_docs = searcher.search(&query, &filter_some_collector)?; @@ -50,7 +50,7 @@ pub fn test_filter_collector() -> crate::Result<()> { let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new( "price".to_string(), &|value| value < 5u64, - TopDocs::with_limit(2), + TopDocs::with_limit(2).order_by_score(), ); let filtered_top_docs = searcher.search(&query, &filter_all_collector).unwrap(); @@ -62,8 +62,11 @@ pub fn test_filter_collector() -> crate::Result<()> { > 0 } - let filter_dates_collector = - FilterCollector::new("date".to_string(), &date_filter, TopDocs::with_limit(5)); + let filter_dates_collector = FilterCollector::new( + "date".to_string(), + &date_filter, + TopDocs::with_limit(5).order_by_score(), + ); let filtered_date_docs = searcher.search(&query, &filter_dates_collector)?; assert_eq!(filtered_date_docs.len(), 2); diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index 29ff08600..6981c86c9 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -1,12 +1,7 @@ use std::cmp::Ordering; -use std::marker::PhantomData; use serde::{Deserialize, Serialize}; -use super::top_score_collector::TopNComputer; -use crate::index::SegmentReader; -use crate::{DocAddress, DocId, SegmentOrdinal}; - /// Contains a feature (field, score, etc.) of a document along with the document address. /// /// It guarantees stable sorting: in case of a tie on the feature, the document @@ -19,7 +14,7 @@ use crate::{DocAddress, DocId, SegmentOrdinal}; pub struct ComparableDoc { /// The feature of the document. In practice, this is /// is any type that implements `PartialOrd`. - pub feature: T, + pub sort_key: T, /// The document address. In practice, this is any /// type that implements `PartialOrd`, and is guaranteed /// to be unique for each document. 
@@ -28,9 +23,9 @@ pub struct ComparableDoc { impl std::fmt::Debug for ComparableDoc { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str()) - .field("feature", &self.feature) + .field("feature", &self.sort_key) .field("doc", &self.doc) .finish() } @@ -46,8 +41,8 @@ impl Ord for ComparableDoc #[inline] fn cmp(&self, other: &Self) -> Ordering { let by_feature = self - .feature - .partial_cmp(&other.feature) + .sort_key + .partial_cmp(&other.sort_key) .map(|ord| if R { ord.reverse() } else { ord }) .unwrap_or(Ordering::Equal); @@ -67,308 +62,3 @@ impl PartialEq for ComparableDoc Eq for ComparableDoc {} - -pub(crate) struct TopCollector { - pub limit: usize, - pub offset: usize, - _marker: PhantomData, -} - -impl TopCollector -where T: PartialOrd + Clone -{ - /// Creates a top collector, with a number of documents equal to "limit". - /// - /// # Panics - /// The method panics if limit is 0 - pub fn with_limit(limit: usize) -> TopCollector { - assert!(limit >= 1, "Limit must be strictly greater than 0."); - Self { - limit, - offset: 0, - _marker: PhantomData, - } - } - - /// Skip the first "offset" documents when collecting. - /// - /// This is equivalent to `OFFSET` in MySQL or PostgreSQL and `start` in - /// Lucene's TopDocsCollector. - pub fn and_offset(mut self, offset: usize) -> TopCollector { - self.offset = offset; - self - } - - pub fn merge_fruits( - &self, - children: Vec>, - ) -> crate::Result> { - if self.limit == 0 { - return Ok(Vec::new()); - } - let mut top_collector: TopNComputer<_, _> = TopNComputer::new(self.limit + self.offset); - for child_fruit in children { - for (feature, doc) in child_fruit { - top_collector.push(feature, doc); - } - } - - Ok(top_collector - .into_sorted_vec() - .into_iter() - .skip(self.offset) - .map(|cdoc| (cdoc.feature, cdoc.doc)) - .collect()) - } - - pub(crate) fn for_segment( - &self, - segment_id: SegmentOrdinal, - _: &SegmentReader, - ) -> TopSegmentCollector { - TopSegmentCollector::new(segment_id, self.limit + self.offset) - } - - /// Create a new TopCollector with the same limit and offset. - /// - /// Ideally we would use Into but the blanket implementation seems to cause the Scorer traits - /// to fail. - #[doc(hidden)] - pub(crate) fn into_tscore(self) -> TopCollector { - TopCollector { - limit: self.limit, - offset: self.offset, - _marker: PhantomData, - } - } -} - -/// The Top Collector keeps track of the K documents -/// sorted by type `T`. -/// -/// The implementation is based on a repeatedly truncating on the median after K * 2 documents -/// The theoretical complexity for collecting the top `K` out of `n` documents -/// is `O(n + K)`. -pub(crate) struct TopSegmentCollector { - /// We reverse the order of the feature in order to - /// have top-semantics instead of bottom semantics. 
- topn_computer: TopNComputer, - segment_ord: u32, -} - -impl TopSegmentCollector { - fn new(segment_ord: SegmentOrdinal, limit: usize) -> TopSegmentCollector { - TopSegmentCollector { - topn_computer: TopNComputer::new(limit), - segment_ord, - } - } -} - -impl TopSegmentCollector { - pub fn harvest(self) -> Vec<(T, DocAddress)> { - let segment_ord = self.segment_ord; - self.topn_computer - .into_sorted_vec() - .into_iter() - .map(|comparable_doc| { - ( - comparable_doc.feature, - DocAddress { - segment_ord, - doc_id: comparable_doc.doc, - }, - ) - }) - .collect() - } - - /// Collects a document scored by the given feature - /// - /// It collects documents until it has reached the max capacity. Once it reaches capacity, it - /// will compare the lowest scoring item with the given one and keep whichever is greater. - #[inline] - pub fn collect(&mut self, doc: DocId, feature: T) { - self.topn_computer.push(feature, doc); - } -} - -#[cfg(test)] -mod tests { - use super::{TopCollector, TopSegmentCollector}; - use crate::DocAddress; - - #[test] - fn test_top_collector_not_at_capacity() { - let mut top_collector = TopSegmentCollector::new(0, 4); - top_collector.collect(1, 0.8); - top_collector.collect(3, 0.2); - top_collector.collect(5, 0.3); - assert_eq!( - top_collector.harvest(), - vec![ - (0.8, DocAddress::new(0, 1)), - (0.3, DocAddress::new(0, 5)), - (0.2, DocAddress::new(0, 3)) - ] - ); - } - - #[test] - fn test_top_collector_at_capacity() { - let mut top_collector = TopSegmentCollector::new(0, 4); - top_collector.collect(1, 0.8); - top_collector.collect(3, 0.2); - top_collector.collect(5, 0.3); - top_collector.collect(7, 0.9); - top_collector.collect(9, -0.2); - assert_eq!( - top_collector.harvest(), - vec![ - (0.9, DocAddress::new(0, 7)), - (0.8, DocAddress::new(0, 1)), - (0.3, DocAddress::new(0, 5)), - (0.2, DocAddress::new(0, 3)) - ] - ); - } - - #[test] - fn test_top_segment_collector_stable_ordering_for_equal_feature() { - // given that the documents are collected in ascending doc id order, - // when harvesting we have to guarantee stable sorting in case of a tie - // on the score - let doc_ids_collection = [4, 5, 6]; - let score = 3.3f32; - - let mut top_collector_limit_2 = TopSegmentCollector::new(0, 2); - for id in &doc_ids_collection { - top_collector_limit_2.collect(*id, score); - } - - let mut top_collector_limit_3 = TopSegmentCollector::new(0, 3); - for id in &doc_ids_collection { - top_collector_limit_3.collect(*id, score); - } - - assert_eq!( - top_collector_limit_2.harvest(), - top_collector_limit_3.harvest()[..2].to_vec(), - ); - } - - #[test] - fn test_top_collector_with_limit_and_offset() { - let collector = TopCollector::with_limit(2).and_offset(1); - - let results = collector - .merge_fruits(vec![vec![ - (0.9, DocAddress::new(0, 1)), - (0.8, DocAddress::new(0, 2)), - (0.7, DocAddress::new(0, 3)), - (0.6, DocAddress::new(0, 4)), - (0.5, DocAddress::new(0, 5)), - ]]) - .unwrap(); - - assert_eq!( - results, - vec![(0.8, DocAddress::new(0, 2)), (0.7, DocAddress::new(0, 3)),] - ); - } - - #[test] - fn test_top_collector_with_limit_larger_than_set_and_offset() { - let collector = TopCollector::with_limit(2).and_offset(1); - - let results = collector - .merge_fruits(vec![vec![ - (0.9, DocAddress::new(0, 1)), - (0.8, DocAddress::new(0, 2)), - ]]) - .unwrap(); - - assert_eq!(results, vec![(0.8, DocAddress::new(0, 2)),]); - } - - #[test] - fn test_top_collector_with_limit_and_offset_larger_than_set() { - let collector = TopCollector::with_limit(2).and_offset(20); - - let results 
= collector - .merge_fruits(vec![vec![ - (0.9, DocAddress::new(0, 1)), - (0.8, DocAddress::new(0, 2)), - ]]) - .unwrap(); - - assert_eq!(results, vec![]); - } -} - -#[cfg(all(test, feature = "unstable"))] -mod bench { - use test::Bencher; - - use super::TopSegmentCollector; - - #[bench] - fn bench_top_segment_collector_collect_not_at_capacity(b: &mut Bencher) { - let mut top_collector = TopSegmentCollector::new(0, 400); - - b.iter(|| { - for i in 0..100 { - top_collector.collect(i, 0.8); - } - }); - } - - #[bench] - fn bench_top_segment_collector_collect_at_capacity(b: &mut Bencher) { - let mut top_collector = TopSegmentCollector::new(0, 100); - - for i in 0..100 { - top_collector.collect(i, 0.8); - } - - b.iter(|| { - for i in 0..100 { - top_collector.collect(i, 0.8); - } - }); - } - - #[bench] - fn bench_top_segment_collector_collect_and_harvest_many_ties(b: &mut Bencher) { - b.iter(|| { - let mut top_collector = TopSegmentCollector::new(0, 100); - - for i in 0..100 { - top_collector.collect(i, 0.8); - } - - // it would be nice to be able to do the setup N times but still - // measure only harvest(). We can't since harvest() consumes - // the top_collector. - top_collector.harvest() - }); - } - - #[bench] - fn bench_top_segment_collector_collect_and_harvest_no_tie(b: &mut Bencher) { - b.iter(|| { - let mut top_collector = TopSegmentCollector::new(0, 100); - let mut score = 1.0; - - for i in 0..100 { - score += 1.0; - top_collector.collect(i, score); - } - - // it would be nice to be able to do the setup N times but still - // measure only harvest(). We can't since harvest() consumes - // the top_collector. - top_collector.harvest() - }); - } -} diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 33c5df59e..78c344dbe 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -1,247 +1,19 @@ +use std::cmp::Ordering; use std::fmt; -use std::marker::PhantomData; -use std::sync::Arc; +use std::ops::Range; -use columnar::{ColumnValues, StrColumn}; use serde::{Deserialize, Serialize}; use super::Collector; -use crate::collector::custom_score_top_collector::{ - CustomScoreTopCollector, CustomScoreTopSegmentCollector, +use crate::collector::sort_key::{ + Comparator, ComparatorEnum, NaturalComparator, ReverseComparator, SortBySimilarityScore, + SortByStaticFastValue, SortByString, }; -use crate::collector::top_collector::{ComparableDoc, TopCollector, TopSegmentCollector}; -use crate::collector::tweak_score_top_collector::TweakedScoreTopCollector; -use crate::collector::{ - CustomScorer, CustomSegmentScorer, ScoreSegmentTweaker, ScoreTweaker, SegmentCollector, -}; -use crate::fastfield::{FastFieldNotAvailableError, FastValue}; -use crate::query::Weight; -use crate::termdict::TermOrdinal; -use crate::{DocAddress, DocId, Order, Score, SegmentOrdinal, SegmentReader, TantivyError}; - -struct FastFieldConvertCollector< - TCollector: Collector>, - TFastValue: FastValue, -> { - pub collector: TCollector, - pub field: String, - pub fast_value: std::marker::PhantomData, - order: Order, -} - -impl Collector for FastFieldConvertCollector -where - TCollector: Collector>, - TFastValue: FastValue, -{ - type Fruit = Vec<(TFastValue, DocAddress)>; - - type Child = TCollector::Child; - - fn for_segment( - &self, - segment_local_id: crate::SegmentOrdinal, - segment: &SegmentReader, - ) -> crate::Result { - let schema = segment.schema(); - let field = schema.get_field(&self.field)?; - let field_entry = schema.get_field_entry(field); - if 
!field_entry.is_fast() { - return Err(TantivyError::SchemaError(format!( - "Field {:?} is not a fast field.", - field_entry.name() - ))); - } - let schema_type = TFastValue::to_type(); - let requested_type = field_entry.field_type().value_type(); - if schema_type != requested_type { - return Err(TantivyError::SchemaError(format!( - "Field {:?} is of type {schema_type:?}!={requested_type:?}", - field_entry.name() - ))); - } - self.collector.for_segment(segment_local_id, segment) - } - - fn requires_scoring(&self) -> bool { - self.collector.requires_scoring() - } - - fn merge_fruits( - &self, - segment_fruits: Vec<::Fruit>, - ) -> crate::Result { - let raw_result = self.collector.merge_fruits(segment_fruits)?; - let transformed_result = raw_result - .into_iter() - .map(|(score, doc_address)| { - if self.order.is_desc() { - (TFastValue::from_u64(score), doc_address) - } else { - (TFastValue::from_u64(u64::MAX - score), doc_address) - } - }) - .collect::>(); - Ok(transformed_result) - } -} - -struct StringConvertCollector { - pub collector: CustomScoreTopCollector, - pub field: String, - order: Order, - limit: usize, - offset: usize, -} - -impl Collector for StringConvertCollector { - type Fruit = Vec<(String, DocAddress)>; - - type Child = StringConvertSegmentCollector; - - fn for_segment( - &self, - segment_local_id: crate::SegmentOrdinal, - segment: &SegmentReader, - ) -> crate::Result { - let schema = segment.schema(); - let field = schema.get_field(&self.field)?; - let field_entry = schema.get_field_entry(field); - if !field_entry.is_fast() { - return Err(TantivyError::SchemaError(format!( - "Field {:?} is not a fast field.", - field_entry.name() - ))); - } - let requested_type = crate::schema::Type::Str; - let schema_type = field_entry.field_type().value_type(); - if schema_type != requested_type { - return Err(TantivyError::SchemaError(format!( - "Field {:?} is of type {schema_type:?}!={requested_type:?}", - field_entry.name() - ))); - } - let ff = segment - .fast_fields() - .str(&self.field)? 
- .expect("ff should be a str field"); - Ok(StringConvertSegmentCollector { - collector: self.collector.for_segment(segment_local_id, segment)?, - ff, - order: self.order.clone(), - }) - } - - fn requires_scoring(&self) -> bool { - self.collector.requires_scoring() - } - - fn merge_fruits( - &self, - child_fruits: Vec<::Fruit>, - ) -> crate::Result { - if self.limit == 0 { - return Ok(Vec::new()); - } - if self.order.is_desc() { - let mut top_collector: TopNComputer<_, _, true> = - TopNComputer::new(self.limit + self.offset); - for child_fruit in child_fruits { - for (feature, doc) in child_fruit { - top_collector.push(feature, doc); - } - } - Ok(top_collector - .into_sorted_vec() - .into_iter() - .skip(self.offset) - .map(|cdoc| (cdoc.feature, cdoc.doc)) - .collect()) - } else { - let mut top_collector: TopNComputer<_, _, false> = - TopNComputer::new(self.limit + self.offset); - for child_fruit in child_fruits { - for (feature, doc) in child_fruit { - top_collector.push(feature, doc); - } - } - - Ok(top_collector - .into_sorted_vec() - .into_iter() - .skip(self.offset) - .map(|cdoc| (cdoc.feature, cdoc.doc)) - .collect()) - } - } -} - -struct StringConvertSegmentCollector { - pub collector: CustomScoreTopSegmentCollector, - ff: StrColumn, - order: Order, -} - -impl SegmentCollector for StringConvertSegmentCollector { - type Fruit = Vec<(String, DocAddress)>; - - fn collect(&mut self, doc: DocId, score: Score) { - self.collector.collect(doc, score); - } - - fn harvest(self) -> Vec<(String, DocAddress)> { - let top_ordinals: Vec<(TermOrdinal, DocAddress)> = self.collector.harvest(); - - // Collect terms. - let mut terms: Vec = Vec::with_capacity(top_ordinals.len()); - let result = if self.order.is_asc() { - self.ff.dictionary().sorted_ords_to_term_cb( - top_ordinals.iter().map(|(term_ord, _)| u64::MAX - term_ord), - |term| { - terms.push( - std::str::from_utf8(term) - .expect("Failed to decode term as unicode") - .to_owned(), - ); - Ok(()) - }, - ) - } else { - self.ff.dictionary().sorted_ords_to_term_cb( - top_ordinals.iter().rev().map(|(term_ord, _)| *term_ord), - |term| { - terms.push( - std::str::from_utf8(term) - .expect("Failed to decode term as unicode") - .to_owned(), - ); - Ok(()) - }, - ) - }; - - assert!( - result.expect("Failed to read terms from term dictionary"), - "Not all terms were matched in segment." - ); - - // Zip them back with their docs. - if self.order.is_asc() { - terms - .into_iter() - .zip(top_ordinals) - .map(|(term, (_, doc))| (term, doc)) - .collect() - } else { - terms - .into_iter() - .rev() - .zip(top_ordinals) - .map(|(term, (_, doc))| (term, doc)) - .collect() - } - } -} +use crate::collector::sort_key_top_collector::TopBySortKeyCollector; +use crate::collector::top_collector::ComparableDoc; +use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; +use crate::fastfield::FastValue; +use crate::{DocAddress, DocId, Order, Score, SegmentReader}; /// The `TopDocs` collector keeps track of the top `K` documents /// sorted by their score. 
@@ -280,78 +52,48 @@ impl SegmentCollector for StringConvertSegmentCollector {
 ///
 /// let query_parser = QueryParser::for_index(&index, vec![title]);
 /// let query = query_parser.parse_query("diary")?;
-/// let top_docs = searcher.search(&query, &TopDocs::with_limit(2))?;
+/// let top_docs = searcher.search(&query, &TopDocs::with_limit(2).order_by_score())?;
 ///
 /// assert_eq!(top_docs[0].1, DocAddress::new(0, 1));
 /// assert_eq!(top_docs[1].1, DocAddress::new(0, 3));
 /// # Ok(())
 /// # }
 /// ```
-pub struct TopDocs(TopCollector<Score>);
+pub struct TopDocs {
+    limit: usize,
+    offset: usize,
+}
 
 impl fmt::Debug for TopDocs {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        write!(
-            f,
-            "TopDocs(limit={}, offset={})",
-            self.0.limit, self.0.offset
-        )
-    }
-}
-
-struct ScorerByFastFieldReader {
-    sort_column: Arc<dyn ColumnValues<u64>>,
-    order: Order,
-}
-
-impl CustomSegmentScorer<u64> for ScorerByFastFieldReader {
-    fn score(&mut self, doc: DocId) -> u64 {
-        let value = self.sort_column.get_val(doc);
-        if self.order.is_desc() {
-            value
-        } else {
-            u64::MAX - value
-        }
-    }
-}
-
-struct ScorerByField {
-    field: String,
-    order: Order,
-}
-
-impl CustomScorer<u64> for ScorerByField {
-    type Child = ScorerByFastFieldReader;
-
-    fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child> {
-        // We interpret this field as u64, regardless of its type, that way,
-        // we avoid needless conversion. Regardless of the fast field type, the
-        // mapping is monotonic, so it is sufficient to compute our top-K docs.
-        //
-        // The conversion will then happen only on the top-K docs.
-        let sort_column_opt = segment_reader.fast_fields().u64_lenient(&self.field)?;
-        let (sort_column, _sort_column_type) =
-            sort_column_opt.ok_or_else(|| FastFieldNotAvailableError {
-                field_name: self.field.clone(),
-            })?;
-        let mut default_value = 0u64;
-        if self.order.is_asc() {
-            default_value = u64::MAX;
-        }
-        Ok(ScorerByFastFieldReader {
-            sort_column: sort_column.first_or_default_col(default_value),
-            order: self.order.clone(),
-        })
+        write!(f, "TopDocs(limit={}, offset={})", self.limit, self.offset)
     }
 }
 
 impl TopDocs {
+    /// Builds a `TopDocs` collector capturing a given document range.
+    ///
+    /// The range `start..end` translates into a limit of `end - start`
+    /// and an offset of `start`.
+    pub fn for_doc_range(doc_range: Range<usize>) -> Self {
+        TopDocs {
+            limit: doc_range.end.saturating_sub(doc_range.start),
+            offset: doc_range.start,
+        }
+    }
+
+    /// Returns the doc range we are trying to capture.
+    pub fn doc_range(&self) -> Range<usize> {
+        self.offset..self.offset + self.limit
+    }
+
     /// Creates a top score collector, with a number of documents equal to "limit".
     ///
     /// # Panics
     /// The method panics if limit is 0
     pub fn with_limit(limit: usize) -> TopDocs {
-        TopDocs(TopCollector::with_limit(limit))
+        assert_ne!(limit, 0, "Limit must be greater than 0");
+        TopDocs { limit, offset: 0 }
     }
 
     /// Skip the first "offset" documents when collecting.
@@ -386,7 +128,7 @@ impl TopDocs { /// /// let query_parser = QueryParser::for_index(&index, vec![title]); /// let query = query_parser.parse_query("diary")?; - /// let top_docs = searcher.search(&query, &TopDocs::with_limit(2).and_offset(1))?; + /// let top_docs = searcher.search(&query, &TopDocs::with_limit(2).and_offset(1).order_by_score())?; /// /// assert_eq!(top_docs.len(), 2); /// assert_eq!(top_docs[0].1, DocAddress::new(0, 4)); @@ -396,7 +138,10 @@ impl TopDocs { /// ``` #[must_use] pub fn and_offset(self, offset: usize) -> TopDocs { - TopDocs(self.0.and_offset(offset)) + TopDocs { + limit: self.limit, + offset, + } } /// Set top-K to rank documents by a given fast field. @@ -434,8 +179,8 @@ impl TopDocs { /// # let query = QueryParser::for_index(&index, vec![title]).parse_query("diary")?; /// # let top_docs = docs_sorted_by_rating(&reader.searcher(), &query)?; /// # assert_eq!(top_docs, - /// # vec![(97u64, DocAddress::new(0u32, 1)), - /// # (80u64, DocAddress::new(0u32, 3))]); + /// # vec![(Some(97u64), DocAddress::new(0u32, 1)), + /// # (Some(80u64), DocAddress::new(0u32, 3))]); /// # Ok(()) /// # } /// /// Searches the document matching the given query, and @@ -443,7 +188,7 @@ impl TopDocs { /// /// given in argument. /// fn docs_sorted_by_rating(searcher: &Searcher, /// query: &dyn Query) - /// -> tantivy::Result> { + /// -> tantivy::Result, DocAddress)>> { /// /// // This is where we build our topdocs collector /// // @@ -459,7 +204,7 @@ impl TopDocs { /// // The vec is sorted decreasingly by `sort_by_field`, and has a /// // length of 10, or less if not enough documents matched the /// // query. - /// let resulting_docs: Vec<(u64, DocAddress)> = + /// let resulting_docs: Vec<(Option, DocAddress)> = /// searcher.search(query, &top_books_by_rating)?; /// /// Ok(resulting_docs) @@ -474,14 +219,13 @@ impl TopDocs { self, field: impl ToString, order: Order, - ) -> impl Collector> { - CustomScoreTopCollector::new( - ScorerByField { - field: field.to_string(), - order, - }, - self.0.into_tscore(), - ) + ) -> impl Collector, DocAddress)>> { + self.order_by((SortByStaticFastValue::for_field(field), order)) + } + + /// Order docs by decreasing BM25 similarity score. + pub fn order_by_score(self) -> impl Collector> { + TopBySortKeyCollector::new(SortBySimilarityScore, self.doc_range()) } /// Set top-K to rank documents by a given fast field. @@ -520,8 +264,8 @@ impl TopDocs { /// # let reader = index.reader()?; /// # let top_docs = docs_sorted_by_revenue(&reader.searcher(), &AllQuery, "revenue")?; /// # assert_eq!(top_docs, - /// # vec![(119_000_000i64, DocAddress::new(0, 1)), - /// # (92_000_000i64, DocAddress::new(0, 0))]); + /// # vec![(Some(119_000_000i64), DocAddress::new(0, 1)), + /// # (Some(92_000_000i64), DocAddress::new(0, 0))]); /// # Ok(()) /// # } /// /// Searches the document matching the given query, and @@ -530,7 +274,7 @@ impl TopDocs { /// fn docs_sorted_by_revenue(searcher: &Searcher, /// query: &dyn Query, /// revenue_field: &str) - /// -> tantivy::Result> { + /// -> tantivy::Result, DocAddress)>> { /// /// // This is where we build our topdocs collector /// // @@ -547,7 +291,7 @@ impl TopDocs { /// // The vec is sorted decreasingly by `sort_by_field`, and has a /// // length of 10, or less if not enough documents matched the /// // query. 
-    ///     let resulting_docs: Vec<(i64, DocAddress)> =
+    ///     let resulting_docs: Vec<(Option<i64>, DocAddress)> =
     ///          searcher.search(query, &top_company_by_revenue)?;
     ///
     ///     Ok(resulting_docs)
@@ -557,17 +301,12 @@ impl TopDocs {
         self,
         fast_field: impl ToString,
         order: Order,
-    ) -> impl Collector<Fruit = Vec<(TFastValue, DocAddress)>>
+    ) -> impl Collector<Fruit = Vec<(Option<TFastValue>, DocAddress)>>
     where
         TFastValue: FastValue,
+        ComparatorEnum: Comparator<Option<TFastValue>>,
     {
-        let u64_collector = self.order_by_u64_field(fast_field.to_string(), order.clone());
-        FastFieldConvertCollector {
-            collector: u64_collector,
-            field: fast_field.to_string(),
-            fast_value: PhantomData,
-            order,
-        }
+        self.order_by((SortByStaticFastValue::for_field(fast_field), order))
     }
 
     /// Like `order_by_fast_field`, but for a `String` fast field.
@@ -575,30 +314,28 @@ impl TopDocs {
         self,
         fast_field: impl ToString,
         order: Order,
-    ) -> impl Collector<Fruit = Vec<(String, DocAddress)>> {
-        let limit = self.0.limit;
-        let offset = self.0.offset;
-        let u64_collector = CustomScoreTopCollector::new(
-            ScorerByField {
-                field: fast_field.to_string(),
-                order: order.clone(),
-            },
-            self.0.into_tscore(),
-        );
-        StringConvertCollector {
-            collector: u64_collector,
-            field: fast_field.to_string(),
-            order,
-            limit,
-            offset,
-        }
+    ) -> impl Collector<Fruit = Vec<(Option<String>, DocAddress)>> {
+        let by_string_sort_key_computer = SortByString::for_field(fast_field.to_string());
+        self.order_by((by_string_sort_key_computer, order))
     }
 
-    /// Ranks the documents using a custom score.
+    /// Ranks the documents using a sort key.
+    pub fn order_by<TSortKey>(
+        self,
+        sort_key_computer: impl SortKeyComputer<SortKey = TSortKey> + Send + 'static,
+    ) -> impl Collector<Fruit = Vec<(TSortKey, DocAddress)>>
+    where
+        TSortKey: 'static + Clone + Send + Sync + PartialOrd + std::fmt::Debug,
+    {
+        TopBySortKeyCollector::new(sort_key_computer, self.doc_range())
+    }
+
+    /// Helper function to tweak the similarity score of documents using a function
+    /// (usually a closure).
     ///
     /// This method offers a convenient way to tweak or replace
     /// the documents score. As suggested by the prototype you can
-    /// manually define your own [`ScoreTweaker`]
+    /// manually define your own [`SortKeyComputer`]
     /// and pass it as an argument, but there is a much simpler way to
     /// tweak your score: you can use a closure as in the following
     /// example.
@@ -686,217 +423,72 @@ impl TopDocs {
     ///     // The `Score` in the pair is our tweaked score.
     ///     let resulting_docs: Vec<(Score, DocAddress)> =
     ///          searcher.search(&query, &top_docs_by_custom_score).unwrap();
-    /// ```
-    ///
-    /// # See also
-    /// - [custom_score(...)](TopDocs::custom_score)
-    pub fn tweak_score<TScore, TScoreSegmentTweaker, TScoreTweaker>(
-        self,
-        score_tweaker: TScoreTweaker,
-    ) -> impl Collector<Fruit = Vec<(TScore, DocAddress)>>
+    /// ```
+    pub fn tweak_score<F, TSortKey>(
+        self,
+        sort_key_fn: F,
+    ) -> impl Collector<Fruit = Vec<(TSortKey, DocAddress)>>
     where
-        TScore: 'static + Send + Sync + Clone + PartialOrd,
-        TScoreSegmentTweaker: ScoreSegmentTweaker<TScore> + 'static,
-        TScoreTweaker: ScoreTweaker<TScore, Child = TScoreSegmentTweaker> + Send + Sync,
+        F: 'static + Send + Sync,
+        TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
+        TweakScoreFn<F>: SortKeyComputer<SortKey = TSortKey>,
     {
-        TweakedScoreTopCollector::new(score_tweaker, self.0.into_tscore())
-    }
-
-    /// Ranks the documents using a custom score.
-    ///
-    /// This method offers a convenient way to use a different score.
-    ///
-    /// As suggested by the prototype you can manually define your own [`CustomScorer`]
-    /// and pass it as an argument, but there is a much simpler way to
-    /// tweak your score: you can use a closure as in the following
-    /// example.
- /// - /// # Limitation - /// - /// This method only makes it possible to compute the score from a given - /// `DocId`, fastfield values for the doc and any information you could - /// have precomputed beforehand. It does not make it possible for instance - /// to compute something like TfIdf as it does not have access to the list of query - /// terms present in the document, nor the term frequencies for the different terms. - /// - /// It can be used if your search engine relies on a learning-to-rank model for instance, - /// which does not rely on the term frequencies or positions as features. - /// - /// # Example - /// - /// ```rust - /// # use tantivy::schema::{Schema, FAST, TEXT}; - /// # use tantivy::{doc, Index, DocAddress, DocId}; - /// # use tantivy::query::QueryParser; - /// use tantivy::SegmentReader; - /// use tantivy::collector::TopDocs; - /// use tantivy::schema::Field; - /// - /// # fn create_schema() -> Schema { - /// # let mut schema_builder = Schema::builder(); - /// # schema_builder.add_text_field("product_name", TEXT); - /// # schema_builder.add_u64_field("popularity", FAST); - /// # schema_builder.add_u64_field("boosted", FAST); - /// # schema_builder.build() - /// # } - /// # - /// # fn main() -> tantivy::Result<()> { - /// # let schema = create_schema(); - /// # let index = Index::create_in_ram(schema); - /// # let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; - /// # let product_name = index.schema().get_field("product_name").unwrap(); - /// # - /// let popularity: Field = index.schema().get_field("popularity").unwrap(); - /// let boosted: Field = index.schema().get_field("boosted").unwrap(); - /// # index_writer.add_document(doc!(boosted=>1u64, product_name => "The Diary of Muadib", popularity => 1u64))?; - /// # index_writer.add_document(doc!(boosted=>0u64, product_name => "A Dairy Cow", popularity => 10u64))?; - /// # index_writer.add_document(doc!(boosted=>0u64, product_name => "The Diary of a Young Girl", popularity => 15u64))?; - /// # index_writer.commit()?; - /// // ... - /// # let user_query = "diary"; - /// # let query = QueryParser::for_index(&index, vec![product_name]).parse_query(user_query)?; - /// - /// // This is where we build our collector with our custom score. - /// let top_docs_by_custom_score = TopDocs - /// ::with_limit(10) - /// .custom_score(move |segment_reader: &SegmentReader| { - /// // The argument is a function that returns our scoring - /// // function. - /// // - /// // The point of this "mother" function is to gather all - /// // of the segment level information we need for scoring. - /// // Typically, fast_fields. - /// // - /// // In our case, we will get a reader for the popularity - /// // fast field and a boosted field. - /// // - /// // We want to get boosted items score, and when we get - /// // a tie, return the item with the highest popularity. - /// // - /// // Note that this is implemented by using a `(u64, u64)` - /// // as a score. - /// let popularity_reader = - /// segment_reader.fast_fields().u64("popularity").unwrap().first_or_default_col(0); - /// let boosted_reader = - /// segment_reader.fast_fields().u64("boosted").unwrap().first_or_default_col(0); - /// - /// // We can now define our actual scoring function - /// move |doc: DocId| { - /// let popularity: u64 = popularity_reader.get_val(doc); - /// let boosted: u64 = boosted_reader.get_val(doc); - /// // Score do not have to be `f64` in tantivy. - /// // Here we return a couple to get lexicographical order - /// // for free. 
-    ///             (boosted, popularity)
-    ///         }
-    ///     });
-    /// # let reader = index.reader()?;
-    /// # let searcher = reader.searcher();
-    /// // ... and here are our documents. Note this is a simple vec.
-    /// // The `Score` in the pair is our tweaked score.
-    /// let resulting_docs: Vec<((u64, u64), DocAddress)> =
-    ///      searcher.search(&*query, &top_docs_by_custom_score)?;
-    ///
-    /// # Ok(())
-    /// # }
-    /// ```
-    ///
-    /// # See also
-    /// - [tweak_score(...)](TopDocs::tweak_score)
-    pub fn custom_score<TScore, TCustomSegmentScorer, TCustomScorer>(
-        self,
-        custom_score: TCustomScorer,
-    ) -> impl Collector<Fruit = Vec<(TScore, DocAddress)>>
-    where
-        TScore: 'static + Send + Sync + Clone + PartialOrd,
-        TCustomSegmentScorer: CustomSegmentScorer<TScore> + 'static,
-        TCustomScorer: CustomScorer<TScore, Child = TCustomSegmentScorer> + Send + Sync,
-    {
-        CustomScoreTopCollector::new(custom_score, self.0.into_tscore())
+        self.order_by(TweakScoreFn(sort_key_fn))
     }
 }
 
-impl Collector for TopDocs {
-    type Fruit = Vec<(Score, DocAddress)>;
+/// Helper struct making it possible to define, from a simple function, a sort
+/// key computer that tweaks the similarity score.
+pub struct TweakScoreFn<F>(F);
 
-    type Child = TopScoreSegmentCollector;
-
-    fn for_segment(
-        &self,
-        segment_local_id: SegmentOrdinal,
-        reader: &SegmentReader,
-    ) -> crate::Result<Self::Child> {
-        let collector = self.0.for_segment(segment_local_id, reader);
-        Ok(TopScoreSegmentCollector(collector))
-    }
+impl<F, TTweakScoreSortKeyFn, TSortKey> SortKeyComputer for TweakScoreFn<F>
+where
+    F: 'static + Send + Sync + Fn(&SegmentReader) -> TTweakScoreSortKeyFn,
+    TTweakScoreSortKeyFn: 'static + Fn(DocId, Score) -> TSortKey,
+    TweakScoreSegmentSortKeyComputer<TTweakScoreSortKeyFn>:
+        SegmentSortKeyComputer<SortKey = TSortKey>,
+    TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
+{
+    type SortKey = TSortKey;
+    type Child = TweakScoreSegmentSortKeyComputer<TTweakScoreSortKeyFn>;
+    type Comparator = NaturalComparator;
 
     fn requires_scoring(&self) -> bool {
         true
     }
 
-    fn merge_fruits(
+    fn segment_sort_key_computer(
         &self,
-        child_fruits: Vec<Vec<(Score, DocAddress)>>,
-    ) -> crate::Result<Self::Fruit> {
-        self.0.merge_fruits(child_fruits)
-    }
-
-    fn collect_segment(
-        &self,
-        weight: &dyn Weight,
-        segment_ord: u32,
-        reader: &SegmentReader,
-    ) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
-        let heap_len = self.0.limit + self.0.offset;
-        let mut top_n: TopNComputer<_, _> = TopNComputer::new(heap_len);
-
-        if let Some(alive_bitset) = reader.alive_bitset() {
-            let mut threshold = Score::MIN;
-            top_n.threshold = Some(threshold);
-            weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
-                if alive_bitset.is_deleted(doc) {
-                    return threshold;
-                }
-                top_n.push(score, doc);
-                threshold = top_n.threshold.unwrap_or(Score::MIN);
-                threshold
-            })?;
-        } else {
-            weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
-                top_n.push(score, doc);
-                top_n.threshold.unwrap_or(Score::MIN)
-            })?;
-        }
-
-        let fruit = top_n
-            .into_sorted_vec()
-            .into_iter()
-            .map(|cid| {
-                (
-                    cid.feature,
-                    DocAddress {
-                        segment_ord,
-                        doc_id: cid.doc,
-                    },
-                )
-            })
-            .collect();
-        Ok(fruit)
+        segment_reader: &SegmentReader,
+    ) -> crate::Result<Self::Child> {
+        Ok(TweakScoreSegmentSortKeyComputer {
+            sort_key_fn: (self.0)(segment_reader),
+        })
     }
 }
 
-/// Segment Collector associated with `TopDocs`.
-pub struct TopScoreSegmentCollector(TopSegmentCollector); +pub struct TweakScoreSegmentSortKeyComputer { + sort_key_fn: TTweakScoreSortKeyFn, +} -impl SegmentCollector for TopScoreSegmentCollector { - type Fruit = Vec<(Score, DocAddress)>; +impl SegmentSortKeyComputer + for TweakScoreSegmentSortKeyComputer +where + TTweakScoreSortKeyFn: 'static + Fn(DocId, Score) -> TSortKey, + TSortKey: 'static + PartialOrd + Clone + Send + Sync, +{ + type SortKey = TSortKey; + type SegmentSortKey = TSortKey; - fn collect(&mut self, doc: DocId, score: Score) { - self.0.collect(doc, score); + fn segment_sort_key(&mut self, doc: DocId, score: Score) -> TSortKey { + (self.sort_key_fn)(doc, score) } - fn harvest(self) -> Vec<(Score, DocAddress)> { - self.0.harvest() + /// Convert a segment level score into the global level score. + fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { + sort_key } } @@ -907,53 +499,30 @@ impl SegmentCollector for TopScoreSegmentCollector { /// That means capacity has special meaning and should be carried over when cloning or serializing. /// /// For TopN == 0, it will be relative expensive. +/// +/// When using the natural comparator, the top N computer returns the top N elements in +/// descending order, as expected for a top N. #[derive(Serialize, Deserialize)] -#[serde(from = "TopNComputerDeser")] -pub struct TopNComputer { +#[serde(from = "TopNComputerDeser")] +pub struct TopNComputer { /// The buffer reverses sort order to get top-semantics instead of bottom-semantics - buffer: Vec>, + buffer: Vec>, top_n: usize, pub(crate) threshold: Option, -} - -impl std::fmt::Debug - for TopNComputer -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("TopNComputer") - .field("buffer_len", &self.buffer.len()) - .field("top_n", &self.top_n) - .field("current_threshold", &self.threshold) - .finish() - } + comparator: C, } // Intermediate struct for TopNComputer for deserialization, to keep vec capacity #[derive(Deserialize)] -struct TopNComputerDeser { - buffer: Vec>, +struct TopNComputerDeser { + buffer: Vec>, top_n: usize, threshold: Option, + comparator: C, } -// Custom clone to keep capacity -impl Clone - for TopNComputer -{ - fn clone(&self) -> Self { - let mut buffer_clone = Vec::with_capacity(self.buffer.capacity()); - buffer_clone.extend(self.buffer.iter().cloned()); - - TopNComputer { - buffer: buffer_clone, - top_n: self.top_n, - threshold: self.threshold.clone(), - } - } -} - -impl From> for TopNComputer { - fn from(mut value: TopNComputerDeser) -> Self { +impl From> for TopNComputer { + fn from(mut value: TopNComputerDeser) -> Self { let expected_cap = value.top_n.max(1) * 2; let current_cap = value.buffer.capacity(); if current_cap < expected_cap { @@ -966,62 +535,106 @@ impl From> for TopNCompu buffer: value.buffer, top_n: value.top_n, threshold: value.threshold, + comparator: value.comparator, } } } -impl TopNComputer +impl std::fmt::Debug for TopNComputer +where C: Comparator +{ + fn fmt(&self, f: &mut fmt::Formatter) -> std::fmt::Result { + f.debug_struct("TopNComputer") + .field("buffer_len", &self.buffer.len()) + .field("top_n", &self.top_n) + .field("current_threshold", &self.threshold) + .field("comparator", &self.comparator) + .finish() + } +} + +// Custom clone to keep capacity +impl Clone for TopNComputer { + fn clone(&self) -> Self { + let mut buffer_clone = Vec::with_capacity(self.buffer.capacity()); + buffer_clone.extend(self.buffer.iter().cloned()); + TopNComputer { + buffer: buffer_clone, + 
+            top_n: self.top_n,
+            threshold: self.threshold.clone(),
+            comparator: self.comparator.clone(),
+        }
+    }
+}
+
+impl<TSortKey, D> TopNComputer<TSortKey, D, ReverseComparator>
 where
-    Score: PartialOrd + Clone,
     D: Ord,
+    TSortKey: Clone,
+    NaturalComparator: Comparator<TSortKey>,
 {
     /// Create a new `TopNComputer`.
     /// Internally it will allocate a buffer of size `2 * top_n`.
     pub fn new(top_n: usize) -> Self {
+        TopNComputer::new_with_comparator(top_n, ReverseComparator)
+    }
+}
+
+impl<TSortKey, D, C> TopNComputer<TSortKey, D, C>
+where
+    D: Ord,
+    TSortKey: Clone,
+    C: Comparator<TSortKey>,
+{
+    /// Create a new `TopNComputer`.
+    /// Internally it will allocate a buffer of size `2 * top_n`.
+    pub fn new_with_comparator(top_n: usize, comparator: C) -> Self {
         let vec_cap = top_n.max(1) * 2;
         TopNComputer {
             buffer: Vec::with_capacity(vec_cap),
             top_n,
             threshold: None,
+            comparator,
         }
     }
 
     /// Push a new document to the top n.
     /// If the document is below the current threshold, it will be ignored.
     #[inline]
-    pub fn push(&mut self, feature: Score, doc: D) {
-        if let Some(last_median) = self.threshold.clone() {
-            if !REVERSE_ORDER && feature > last_median {
-                return;
-            }
-            if REVERSE_ORDER && feature < last_median {
+    pub fn push(&mut self, sort_key: TSortKey, doc: D) {
+        if let Some(last_median) = &self.threshold {
+            if self.comparator.compare(&sort_key, last_median) == Ordering::Less {
                 return;
             }
         }
+        self.append_doc(doc, sort_key);
+    }
+
+    // Append a document to the top n.
+    //
+    // At this point, we need to have established that the doc is above the threshold.
+    #[inline(always)]
+    pub(crate) fn append_doc(&mut self, doc: D, sort_key: TSortKey) {
         if self.buffer.len() == self.buffer.capacity() {
             let median = self.truncate_top_n();
             self.threshold = Some(median);
         }
-
-        // This is faster since it avoids the buffer resizing to be inlined from vec.push()
-        // (this is in the hot path)
-        // TODO: Replace with `push_within_capacity` when it's stabilized
-        let uninit = self.buffer.spare_capacity_mut();
         // This cannot panic, because truncate_top_n will remove at least one element, since
         // the min capacity is 2.
-        uninit[0].write(ComparableDoc { doc, feature });
-        // This is safe because it would panic in the line above
-        unsafe {
-            self.buffer.set_len(self.buffer.len() + 1);
-        }
+        let comparable_doc = ComparableDoc { doc, sort_key };
+        push_assuming_capacity(comparable_doc, &mut self.buffer);
     }
 
     #[inline(never)]
-    fn truncate_top_n(&mut self) -> Score {
-        // Use select_nth_unstable to find the top nth score
-        let (_, median_el, _) = self.buffer.select_nth_unstable(self.top_n);
+    fn truncate_top_n(&mut self) -> TSortKey {
+        // Use select_nth_unstable_by to find the top nth sort key
+        let (_, median_el, _) = self.buffer.select_nth_unstable_by(self.top_n, |lhs, rhs| {
+            self.comparator
+                .compare(&rhs.sort_key, &lhs.sort_key)
+                .then_with(|| lhs.doc.cmp(&rhs.doc))
+        });
 
-        let median_score = median_el.feature.clone();
+        let median_score = median_el.sort_key.clone();
         // Remove all elements below the top_n
         self.buffer.truncate(self.top_n);
@@ -1029,18 +642,22 @@
     }
 
     /// Returns the top n elements in sorted order.
-    pub fn into_sorted_vec(mut self) -> Vec<ComparableDoc<Score, D, REVERSE_ORDER>> {
+    pub fn into_sorted_vec(mut self) -> Vec<ComparableDoc<TSortKey, D>> {
         if self.buffer.len() > self.top_n {
             self.truncate_top_n();
         }
-        self.buffer.sort_unstable();
+        self.buffer.sort_unstable_by(|left, right| {
+            self.comparator
+                .compare(&right.sort_key, &left.sort_key)
+                .then_with(|| left.doc.cmp(&right.doc))
+        });
         self.buffer
     }
 
     /// Returns the top n elements in stored order.
     /// Useful if you do not need the elements in sorted order,
     /// for example when merging the results of multiple segments.
-    pub fn into_vec(mut self) -> Vec<ComparableDoc<Score, D, REVERSE_ORDER>> {
+    pub fn into_vec(mut self) -> Vec<ComparableDoc<TSortKey, D>> {
         if self.buffer.len() > self.top_n {
             self.truncate_top_n();
         }
@@ -1048,11 +665,28 @@ where
     }
 }
 
+// Push an element provided there is enough capacity to do so.
+//
+// Panics if there is not enough capacity to add an element.
+#[inline(always)]
+fn push_assuming_capacity<T>(el: T, buf: &mut Vec<T>) {
+    let prev_len = buf.len();
+    assert!(prev_len < buf.capacity());
+    // This is mimicking the current (non-stabilized) implementation in std.
+    // SAFETY: we just checked we have enough capacity.
+    unsafe {
+        let end = buf.as_mut_ptr().add(prev_len);
+        std::ptr::write(end, el);
+        buf.set_len(prev_len + 1);
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use proptest::prelude::*;
 
     use super::{TopDocs, TopNComputer};
+    use crate::collector::sort_key::{ComparatorEnum, NaturalComparator, ReverseComparator};
     use crate::collector::top_collector::ComparableDoc;
     use crate::collector::{Collector, DocSetCollector};
     use crate::query::{AllQuery, Query, QueryParser};
@@ -1084,38 +718,22 @@ mod tests {
             crate::assert_nearly_equals!(result.0, expected.0);
         }
     }
-    #[test]
-    fn test_topn_computer_serde() {
-        let computer: TopNComputer<u32, u32> = TopNComputer::new(1);
-
-        let computer_ser = serde_json::to_string(&computer).unwrap();
-        let mut computer: TopNComputer<u32, u32> = serde_json::from_str(&computer_ser).unwrap();
-
-        computer.push(1u32, 5u32);
-        computer.push(1u32, 0u32);
-        computer.push(1u32, 7u32);
-
-        assert_eq!(
-            computer.into_sorted_vec(),
-            &[ComparableDoc {
-                feature: 1u32,
-                doc: 0u32,
-            },]
-        );
-    }
 
     #[test]
     fn test_empty_topn_computer() {
-        let mut computer: TopNComputer<u32, u32> = TopNComputer::new(0);
+        let mut computer: TopNComputer<u32, u32, NaturalComparator> =
+            TopNComputer::new_with_comparator(0, NaturalComparator);
 
         computer.push(1u32, 1u32);
         computer.push(1u32, 2u32);
         computer.push(1u32, 3u32);
-        assert!(computer.into_sorted_vec().is_empty());
+        assert!(computer.into_vec().is_empty());
     }
+
     #[test]
     fn test_topn_computer() {
-        let mut computer: TopNComputer<u32, u32> = TopNComputer::new(2);
+        let mut computer: TopNComputer<u32, u32, NaturalComparator> =
+            TopNComputer::new_with_comparator(2, NaturalComparator);
 
         computer.push(1u32, 1u32);
         computer.push(2u32, 2u32);
@@ -1126,11 +744,11 @@
             computer.into_sorted_vec(),
             &[
                 ComparableDoc {
-                    feature: 3u32,
+                    sort_key: 3u32,
                     doc: 3u32,
                 },
                 ComparableDoc {
-                    feature: 2u32,
+                    sort_key: 2u32,
                     doc: 2u32,
                 }
             ]
@@ -1140,7 +758,8 @@
     #[test]
     fn test_topn_computer_no_panic() {
         for top_n in 0..10 {
-            let mut computer: TopNComputer<u32, u32> = TopNComputer::new(top_n);
+            let mut computer: TopNComputer<u32, u32, NaturalComparator> =
+                TopNComputer::new_with_comparator(top_n, NaturalComparator);
 
             for _ in 0..1 + top_n * 2 {
                 computer.push(1u32, 1u32);
@@ -1155,29 +774,11 @@
         limit in 0..10_usize,
         docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
     ) {
-        let mut computer: TopNComputer<_, _, false> = TopNComputer::new(limit);
+        let mut computer: TopNComputer<_, _, ReverseComparator> = TopNComputer::new_with_comparator(limit, ReverseComparator);
         for (feature, doc) in &docs {
            computer.push(*feature, *doc);
        }
-        let mut comparable_docs = docs.into_iter().map(|(feature, doc)| ComparableDoc { feature, doc }).collect::<Vec<_>>();
-        comparable_docs.sort();
-        comparable_docs.truncate(limit);
-        prop_assert_eq!(
-            computer.into_sorted_vec(),
-            comparable_docs,
-        );
-    }
-
-    #[test]
-    fn test_topn_computer_desc_prop(
-        limit in 0..10_usize,
-        docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
-    ) {
-        let mut computer: TopNComputer<_, _, true> = TopNComputer::new(limit);
-        for (feature, doc) in &docs {
-            computer.push(*feature, *doc);
-        }
-        let mut comparable_docs = docs.into_iter().map(|(feature, doc)| ComparableDoc { feature, doc }).collect::<Vec<_>>();
+        let mut comparable_docs: Vec<ComparableDoc<u64, u64>> = docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::<Vec<_>>();
         comparable_docs.sort();
         comparable_docs.truncate(limit);
         prop_assert_eq!(
@@ -1196,7 +797,7 @@ mod tests {
         let score_docs: Vec<(Score, DocAddress)> = index
             .reader()?
            .searcher()
-            .search(&text_query, &TopDocs::with_limit(4))?;
+            .search(&text_query, &TopDocs::with_limit(4).order_by_score())?;
         assert_results_equals(
             &score_docs,
             &[
@@ -1218,7 +819,10 @@
             .reader()
             .unwrap()
             .searcher()
-            .search(&text_query, &TopDocs::with_limit(4).and_offset(2))
+            .search(
+                &text_query,
+                &TopDocs::with_limit(4).and_offset(2).order_by_score(),
+            )
             .unwrap();
         assert_results_equals(&score_docs[..], &[(0.48527452, DocAddress::new(0, 0))]);
     }
@@ -1233,7 +837,7 @@
             .reader()
             .unwrap()
             .searcher()
-            .search(&text_query, &TopDocs::with_limit(2))
+            .search(&text_query, &TopDocs::with_limit(2).order_by_score())
             .unwrap();
         assert_results_equals(
             &score_docs,
@@ -1254,7 +858,10 @@
             .reader()
             .unwrap()
             .searcher()
-            .search(&text_query, &TopDocs::with_limit(2).and_offset(1))
+            .search(
+                &text_query,
+                &TopDocs::with_limit(2).and_offset(1).order_by_score(),
+            )
             .unwrap();
         assert_results_equals(
             &score_docs[..],
@@ -1272,11 +879,17 @@
         // using AllQuery to get a constant score
         let searcher = index.reader().unwrap().searcher();
 
-        let page_0 = searcher.search(&AllQuery, &TopDocs::with_limit(1)).unwrap();
+        let page_0 = searcher
+            .search(&AllQuery, &TopDocs::with_limit(1).order_by_score())
+            .unwrap();
 
-        let page_1 = searcher.search(&AllQuery, &TopDocs::with_limit(2)).unwrap();
+        let page_1 = searcher
+            .search(&AllQuery, &TopDocs::with_limit(2).order_by_score())
+            .unwrap();
 
-        let page_2 = searcher.search(&AllQuery, &TopDocs::with_limit(3)).unwrap();
+        let page_2 = searcher
+            .search(&AllQuery, &TopDocs::with_limit(3).order_by_score())
+            .unwrap();
 
         // precondition for the test to be meaningful: we did get documents
         // with the same score
@@ -1324,7 +937,7 @@
         let total_docs: usize = docs_per_segment.iter().sum();
         // Full result set, first assert all scores are identical.
         let full_with_scores: Vec<(Score, DocAddress)> = searcher
-            .search(&AllQuery, &TopDocs::with_limit(total_docs))
+            .search(&AllQuery, &TopDocs::with_limit(total_docs).order_by_score())
             .unwrap();
         // Sanity: at least one document was returned.
         prop_assert!(!full_with_scores.is_empty());
@@ -1344,7 +957,7 @@
         // 1) Increasing limit should preserve prefix ordering.
         for k in 1..=total_docs {
             let page: Vec<DocAddress> = searcher
-                .search(&AllQuery, &TopDocs::with_limit(k))
+                .search(&AllQuery, &TopDocs::with_limit(k).order_by_score())
                 .unwrap()
                 .into_iter()
                 .map(|(_score, addr)| addr)
@@ -1362,7 +975,7 @@
 
         let assert_page_eq = |limit: usize| -> proptest::test_runner::TestCaseResult {
             let page: Vec<DocAddress> = searcher
-                .search(&AllQuery, &TopDocs::with_limit(limit).and_offset(offset))
+                .search(&AllQuery, &TopDocs::with_limit(limit).and_offset(offset).order_by_score())
                 .unwrap()
                 .into_iter()
                 .map(|(_score, addr)| addr)
@@ -1386,7 +999,7 @@
         while offset < total_docs {
             let size = page_size.min(total_docs - offset);
             let page: Vec<DocAddress> = searcher
-                .search(&AllQuery, &TopDocs::with_limit(size).and_offset(offset))
+                .search(&AllQuery, &TopDocs::with_limit(size).and_offset(offset).order_by_score())
                 .unwrap()
                 .into_iter()
                 .map(|(_score, addr)| addr)
@@ -1438,7 +1051,7 @@
 
         // Full result set, first assert all scores are identical across docs.
         let full_with_scores: Vec<(Score, DocAddress)> = searcher
-            .search(&tq, &TopDocs::with_limit(total_docs))
+            .search(&tq, &TopDocs::with_limit(total_docs).order_by_score())
             .unwrap();
         // Sanity: at least one document was returned.
         prop_assert!(!full_with_scores.is_empty());
@@ -1458,7 +1071,7 @@
         // 1) Increasing limit should preserve prefix ordering.
         for k in 1..=total_docs {
             let page: Vec<DocAddress> = searcher
-                .search(&tq, &TopDocs::with_limit(k))
+                .search(&tq, &TopDocs::with_limit(k).order_by_score())
                 .unwrap()
                 .into_iter()
                 .map(|(_score, addr)| addr)
@@ -1473,7 +1086,7 @@
 
         let assert_page_eq = |limit: usize| -> proptest::test_runner::TestCaseResult {
             let page: Vec<DocAddress> = searcher
-                .search(&tq, &TopDocs::with_limit(limit).and_offset(offset))
+                .search(&tq, &TopDocs::with_limit(limit).and_offset(offset).order_by_score())
                 .unwrap()
                 .into_iter()
                 .map(|(_score, addr)| addr)
@@ -1494,7 +1107,7 @@
         while offset < total_docs {
             let size = page_size.min(total_docs - offset);
             let page: Vec<DocAddress> = searcher
-                .search(&tq, &TopDocs::with_limit(size).and_offset(offset))
+                .search(&tq, &TopDocs::with_limit(size).and_offset(offset).order_by_score())
                 .unwrap()
                 .into_iter()
                 .map(|(_score, addr)| addr)
@@ -1545,13 +1158,13 @@
         let searcher = index.reader()?.searcher();
         let top_collector = TopDocs::with_limit(4).order_by_u64_field(SIZE, Order::Desc);
-        let top_docs: Vec<(u64, DocAddress)> = searcher.search(&query, &top_collector)?;
+        let top_docs: Vec<(Option<u64>, DocAddress)> = searcher.search(&query, &top_collector)?;
         assert_eq!(
             &top_docs[..],
             &[
-                (64, DocAddress::new(0, 1)),
-                (16, DocAddress::new(0, 2)),
-                (12, DocAddress::new(0, 0))
+                (Some(64), DocAddress::new(0, 1)),
+                (Some(16), DocAddress::new(0, 2)),
+                (Some(12), DocAddress::new(0, 0))
             ]
         );
         Ok(())
@@ -1584,12 +1197,13 @@
         index_writer.commit()?;
         let searcher = index.reader()?.searcher();
         let top_collector = TopDocs::with_limit(3).order_by_fast_field("birthday", Order::Desc);
-        let top_docs: Vec<(DateTime, DocAddress)> = searcher.search(&AllQuery, &top_collector)?;
+        let top_docs: Vec<(Option<DateTime>, DocAddress)> =
+            searcher.search(&AllQuery, &top_collector)?;
         assert_eq!(
             &top_docs[..],
             &[
-                (mr_birthday, DocAddress::new(0, 1)),
-                (pr_birthday, DocAddress::new(0, 0)),
+                (Some(mr_birthday), DocAddress::new(0, 1)),
+                (Some(pr_birthday), DocAddress::new(0, 0)),
             ]
         );
         Ok(())
@@ -1614,12 +1228,13 @@
         index_writer.commit()?;
         let searcher = index.reader()?.searcher();
         let top_collector =
             TopDocs::with_limit(3).order_by_fast_field("altitude", Order::Desc);
-        let top_docs: Vec<(i64, DocAddress)> = searcher.search(&AllQuery, &top_collector)?;
+        let top_docs: Vec<(Option<i64>, DocAddress)> =
+            searcher.search(&AllQuery, &top_collector)?;
         assert_eq!(
             &top_docs[..],
             &[
-                (40i64, DocAddress::new(0, 1)),
-                (-1i64, DocAddress::new(0, 0)),
+                (Some(40i64), DocAddress::new(0, 1)),
+                (Some(-1i64), DocAddress::new(0, 0)),
             ]
         );
         Ok(())
@@ -1644,12 +1259,13 @@
         index_writer.commit()?;
         let searcher = index.reader()?.searcher();
         let top_collector = TopDocs::with_limit(3).order_by_fast_field("altitude", Order::Desc);
-        let top_docs: Vec<(f64, DocAddress)> = searcher.search(&AllQuery, &top_collector)?;
+        let top_docs: Vec<(Option<f64>, DocAddress)> =
+            searcher.search(&AllQuery, &top_collector)?;
         assert_eq!(
             &top_docs[..],
             &[
-                (40f64, DocAddress::new(0, 1)),
-                (-1.0f64, DocAddress::new(0, 0)),
+                (Some(40f64), DocAddress::new(0, 1)),
+                (Some(-1.0f64), DocAddress::new(0, 0)),
             ]
         );
         Ok(())
@@ -1678,7 +1294,7 @@
         order: Order,
         limit: usize,
         offset: usize,
-    ) -> crate::Result<Vec<(String, DocAddress)>> {
+    ) -> crate::Result<Vec<(Option<String>, DocAddress)>> {
         let searcher = index.reader()?.searcher();
         let top_collector = TopDocs::with_limit(limit)
             .and_offset(offset)
             .order_by_string_fast_field("city", order);
 
     assert_eq!(
         &query(&index, Order::Desc, 3, 0)?,
         &[
-            ("tokyo".to_owned(), DocAddress::new(0, 2)),
-            ("greenville".to_owned(), DocAddress::new(0, 1)),
-            ("austin".to_owned(), DocAddress::new(0, 0)),
+            (Some("tokyo".to_owned()), DocAddress::new(0, 2)),
+            (Some("greenville".to_owned()), DocAddress::new(0, 1)),
+            (Some("austin".to_owned()), DocAddress::new(0, 0)),
         ]
     );
 
     assert_eq!(
         &query(&index, Order::Desc, 2, 0)?,
         &[
-            ("tokyo".to_owned(), DocAddress::new(0, 2)),
-            ("greenville".to_owned(), DocAddress::new(0, 1)),
+            (Some("tokyo".to_owned()), DocAddress::new(0, 2)),
+            (Some("greenville".to_owned()), DocAddress::new(0, 1)),
         ]
     );
 
@@ -1708,33 +1324,33 @@
     assert_eq!(
         &query(&index, Order::Desc, 2, 1)?,
         &[
-            ("greenville".to_owned(), DocAddress::new(0, 1)),
-            ("austin".to_owned(), DocAddress::new(0, 0)),
+            (Some("greenville".to_owned()), DocAddress::new(0, 1)),
+            (Some("austin".to_owned()), DocAddress::new(0, 0)),
         ]
     );
 
     assert_eq!(
         &query(&index, Order::Asc, 3, 0)?,
         &[
-            ("austin".to_owned(), DocAddress::new(0, 0)),
-            ("greenville".to_owned(), DocAddress::new(0, 1)),
-            ("tokyo".to_owned(), DocAddress::new(0, 2)),
+            (Some("austin".to_owned()), DocAddress::new(0, 0)),
+            (Some("greenville".to_owned()), DocAddress::new(0, 1)),
+            (Some("tokyo".to_owned()), DocAddress::new(0, 2)),
         ]
     );
 
     assert_eq!(
         &query(&index, Order::Asc, 2, 1)?,
         &[
-            ("greenville".to_owned(), DocAddress::new(0, 1)),
-            ("tokyo".to_owned(), DocAddress::new(0, 2)),
+            (Some("greenville".to_owned()), DocAddress::new(0, 1)),
+            (Some("tokyo".to_owned()), DocAddress::new(0, 2)),
         ]
     );
 
     assert_eq!(
         &query(&index, Order::Asc, 2, 0)?,
         &[
-            ("austin".to_owned(), DocAddress::new(0, 0)),
-            ("greenville".to_owned(), DocAddress::new(0, 1)),
+            (Some("austin".to_owned()), DocAddress::new(0, 0)),
+            (Some("greenville".to_owned()), DocAddress::new(0, 1)),
         ]
     );
 
@@ -1776,7 +1392,7 @@
         let searcher = index.reader()?.searcher();
         let top_n_results = searcher.search(&AllQuery, &TopDocs::with_limit(limit)
             .and_offset(offset)
-            .order_by_string_fast_field("city", order.clone()))?;
+            .order_by_string_fast_field("city", order))?;
         let all_results = searcher.search(&AllQuery, &DocSetCollector)?.into_iter().map(|doc_address| {
             // Get the term for this address.
             // NOTE: We can't determine the SegmentIds that will be generated for Segments
@@ -1785,21 +1401,21 @@
             let term_ord = column.term_ords(doc_address.doc_id).next().unwrap();
             let mut city = Vec::new();
             column.dictionary().ord_to_term(term_ord, &mut city).unwrap();
-            (String::try_from(city).unwrap(), doc_address)
+            (Some(String::try_from(city).unwrap()), doc_address)
         });
         // Using the TopDocs collector should always be equivalent to sorting, skipping the
         // offset, and then taking the limit.
         let sorted_docs: Vec<_> = if order.is_desc() {
-            let mut comparable_docs: Vec<ComparableDoc<String, DocAddress>> =
-                all_results.into_iter().map(|(feature, doc)| ComparableDoc { feature, doc}).collect();
-            comparable_docs.sort();
-            comparable_docs.into_iter().map(|cd| (cd.feature, cd.doc)).collect()
+            let mut comparable_docs: Vec<ComparableDoc<Option<String>, DocAddress>> =
+                all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
+            comparable_docs.sort();
+            comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
         } else {
-            let mut comparable_docs: Vec<ComparableDoc<String, DocAddress>> =
-                all_results.into_iter().map(|(feature, doc)| ComparableDoc { feature, doc}).collect();
-            comparable_docs.sort();
-            comparable_docs.into_iter().map(|cd| (cd.feature, cd.doc)).collect()
+            let mut comparable_docs: Vec<ComparableDoc<Option<String>, DocAddress>> =
+                all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
+            comparable_docs.sort();
+            comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
         };
         let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
         prop_assert_eq!(
@@ -1850,33 +1466,26 @@
     }
 
     #[test]
-    fn test_field_wrong_type() -> crate::Result<()> {
+    fn test_field_wrong_type() {
         let mut schema_builder = Schema::builder();
-        let size = schema_builder.add_u64_field(SIZE, STORED);
+        let _size = schema_builder.add_u64_field(SIZE, STORED);
         let schema = schema_builder.build();
-        let index = Index::create_in_ram(schema);
-        let mut index_writer = index.writer_for_tests()?;
-        index_writer.add_document(doc!(size=>1u64))?;
-        index_writer.commit()?;
-        let searcher = index.reader()?.searcher();
-        let segment = searcher.segment_reader(0);
         let top_collector = TopDocs::with_limit(4).order_by_fast_field::<u64>(SIZE, Order::Desc);
-        let err = top_collector.for_segment(0, segment).err().unwrap();
+        let err = top_collector.check_schema(&schema).err().unwrap();
         assert!(
-            matches!(err, crate::TantivyError::SchemaError(msg) if msg == "Field \"size\" is not a fast field.")
+            matches!(err, crate::TantivyError::SchemaError(msg) if msg == "Field `size` is not a fast field.")
         );
-        Ok(())
     }
 
     #[test]
-    fn test_tweak_score_top_collector_with_offset() -> crate::Result<()> {
+    fn test_sort_key_top_collector_with_offset() -> crate::Result<()> {
         let index = make_index()?;
         let field = index.schema().get_field("text").unwrap();
         let query_parser = QueryParser::for_index(&index, vec![field]);
         let text_query = query_parser.parse_query("droopy tax")?;
-        let collector = TopDocs::with_limit(2).and_offset(1).tweak_score(
-            move |_segment_reader: &SegmentReader| move |doc: DocId, _original_score: Score| doc,
-        );
+        let collector = TopDocs::with_limit(2)
+            .and_offset(1)
+            .order_by(move |_segment_reader: &SegmentReader| move |doc: DocId| doc);
         let score_docs: Vec<(u32, DocAddress)> =
             index.reader()?.searcher().search(&text_query, &collector)?;
         assert_eq!(
@@ -1894,7 +1503,7 @@
         let text_query = query_parser.parse_query("droopy tax").unwrap();
         let collector = TopDocs::with_limit(2)
             .and_offset(1)
-            .custom_score(move |_segment_reader: &SegmentReader| move |doc: DocId| doc);
+            .order_by(move |_segment_reader: &SegmentReader| move |doc: DocId| doc);
         let score_docs: Vec<(u32, DocAddress)> = index
             .reader()
             .unwrap()
             .searcher()
@@ -1956,22 +1565,23 @@
         let searcher = index.reader()?.searcher();
         let top_collector = TopDocs::with_limit(4).order_by_fast_field(SIZE, Order::Asc);
-        let top_docs: Vec<(u64, DocAddress)> = searcher.search(&query, &top_collector)?;
+        let top_docs: Vec<(Option<u64>, DocAddress)> = searcher.search(&query, &top_collector)?;
         assert_eq!(
             &top_docs[..],
             &[
-                (12, DocAddress::new(0, 0)),
-                (16, DocAddress::new(0, 2)),
-                (64, DocAddress::new(0, 1)),
-                (18446744073709551615, DocAddress::new(0, 3)),
+                (Some(12), DocAddress::new(0, 0)),
+                (Some(16), DocAddress::new(0, 2)),
+                (Some(64), DocAddress::new(0, 1)),
+                (None, DocAddress::new(0, 3)),
             ]
         );
         Ok(())
     }
 
     #[test]
-    fn test_topn_computer_asc() {
-        let mut computer: TopNComputer<u32, u32, false> = TopNComputer::new(2);
+    fn test_topn_computer_desc() {
+        let mut computer: TopNComputer<u32, u32, ComparatorEnum> =
+            TopNComputer::new_with_comparator(2, ComparatorEnum::from(Order::Desc));
 
         computer.push(1u32, 1u32);
         computer.push(2u32, 2u32);
@@ -1983,14 +1593,204 @@
             computer.into_sorted_vec(),
             &[
                 ComparableDoc {
-                    feature: 1u32,
+                    sort_key: 4u32,
+                    doc: 5u32,
+                },
+                ComparableDoc {
+                    sort_key: 3u32,
+                    doc: 3u32,
+                }
+            ]
+        );
+    }
+
+    #[test]
+    fn test_topn_computer_asc() {
+        let mut computer: TopNComputer<u32, u32, ComparatorEnum> =
+            TopNComputer::new_with_comparator(2, ComparatorEnum::from(Order::Asc));
+        computer.push(1u32, 1u32);
+        computer.push(2u32, 2u32);
+        computer.push(3u32, 3u32);
+        computer.push(2u32, 4u32);
+        computer.push(4u32, 5u32);
+        computer.push(1u32, 6u32);
+        assert_eq!(
+            computer.into_sorted_vec(),
+            &[
+                ComparableDoc {
+                    sort_key: 1u32,
                     doc: 1u32,
                 },
                 ComparableDoc {
-                    feature: 1u32,
+                    sort_key: 1u32,
                     doc: 6u32,
                 }
             ]
         );
     }
+
+    #[test]
+    fn test_topn_computer_option_asc_null_at_the_end() {
+        let mut computer: TopNComputer<Option<u32>, u32, _> =
+            TopNComputer::new_with_comparator(2, ComparatorEnum::ReverseNoneLower);
+        computer.push(Some(1u32), 1u32);
+        computer.push(Some(2u32), 2u32);
+        computer.push(None, 3u32);
+        assert_eq!(
+            computer.into_sorted_vec(),
+            &[
+                ComparableDoc {
+                    sort_key: Some(1u32),
+                    doc: 1u32,
+                },
+                ComparableDoc {
+                    sort_key: Some(2u32),
+                    doc: 2u32,
+                }
+            ]
+        );
+    }
+
+    #[test]
+    fn test_topn_computer_option_asc_null_at_the_beginning() {
+        let mut computer: TopNComputer<Option<u32>, u32, _> =
+            TopNComputer::new_with_comparator(2, ComparatorEnum::Reverse);
+        computer.push(Some(1u32), 1u32);
+        computer.push(Some(2u32), 2u32);
+        computer.push(None, 3u32);
+        assert_eq!(
+            computer.into_sorted_vec(),
+            &[
+                ComparableDoc {
+                    sort_key: None,
+                    doc: 3u32,
+                },
+                ComparableDoc {
+                    sort_key: Some(1u32),
+                    doc: 1u32,
+                },
+            ]
+        );
+    }
+
+    #[test]
+    fn test_push_assuming_capacity() {
+        let mut vec = Vec::with_capacity(2);
+        super::push_assuming_capacity(1, &mut vec);
+        assert_eq!(&vec, &[1]);
+        super::push_assuming_capacity(2, &mut vec);
+        assert_eq!(&vec, &[1, 2]);
+    }
+
+    #[test]
+    #[should_panic]
+    fn test_push_assuming_capacity_panics_when_no_cap() {
+        let mut vec = Vec::with_capacity(1);
+        super::push_assuming_capacity(1, &mut vec);
+        assert_eq!(&vec, &[1]);
+        super::push_assuming_capacity(2, &mut vec);
+    }
+
+    #[test]
+    fn test_top_n_computer_not_at_capacity() {
+        let mut top_n_computer = TopNComputer::new_with_comparator(4, NaturalComparator);
+        top_n_computer.append_doc(1, 0.8);
+        top_n_computer.append_doc(3, 0.2);
+        top_n_computer.append_doc(5, 0.3);
+        assert_eq!(
+            &top_n_computer.into_sorted_vec(),
+            &[
+                ComparableDoc {
+                    sort_key: 0.8,
+                    doc: 1
+                },
+                ComparableDoc {
+                    sort_key: 0.3,
+                    doc: 5
+                },
+                ComparableDoc {
+                    sort_key: 0.2,
+                    doc: 3
+                },
+            ]
+        );
+    }
+
+    #[test]
+    fn test_top_n_computer_at_capacity() {
+        let mut top_collector = TopNComputer::new_with_comparator(4, NaturalComparator);
+        top_collector.append_doc(1, 0.8);
+        top_collector.append_doc(3, 0.2);
+        top_collector.append_doc(5, 0.3);
+        top_collector.append_doc(7, 0.9);
+        top_collector.append_doc(9, -0.2);
+        assert_eq!(
+            &top_collector.into_sorted_vec(),
+            &[
+                ComparableDoc {
+                    sort_key: 0.9,
+                    doc: 7
+                },
+                ComparableDoc {
+                    sort_key: 0.8,
+                    doc: 1
+                },
+                ComparableDoc {
+                    sort_key: 0.3,
+                    doc: 5
+                },
+                ComparableDoc {
+                    sort_key: 0.2,
+                    doc: 3
+                },
+            ]
+        );
+    }
+
+    #[test]
+    fn test_top_segment_collector_stable_ordering_for_equal_feature() {
+        // given that the documents are collected in ascending doc id order,
+        // when harvesting we have to guarantee stable sorting in case of a tie
+        // on the score
+        let doc_ids_collection = [4, 5, 6];
+        let score = 3.3f32;
+
+        let mut top_collector_limit_2 = TopNComputer::new_with_comparator(2, NaturalComparator);
+        for id in &doc_ids_collection {
+            top_collector_limit_2.append_doc(*id, score);
+        }
+
+        let mut top_collector_limit_3 = TopNComputer::new_with_comparator(3, NaturalComparator);
+        for id in &doc_ids_collection {
+            top_collector_limit_3.append_doc(*id, score);
+        }
+
+        let docs_limit_2 = top_collector_limit_2.into_sorted_vec();
+        let docs_limit_3 = top_collector_limit_3.into_sorted_vec();
+
+        assert_eq!(&docs_limit_2, &docs_limit_3[..2],);
+    }
+}
+
+#[cfg(all(test, feature = "unstable"))]
+mod bench {
+    use test::Bencher;
+
+    use super::TopNComputer;
+    use crate::collector::sort_key::NaturalComparator;
+
+    #[bench]
+    fn bench_top_segment_collector_collect_at_capacity(b: &mut Bencher) {
+        let mut top_collector = TopNComputer::new_with_comparator(100, NaturalComparator);
+
+        for i in 0..100 {
+            top_collector.append_doc(i, 0.8);
+        }
+
+        b.iter(|| {
+            for i in 0..100 {
+                top_collector.append_doc(i, 0.8);
+            }
+        });
+    }
 }
diff --git a/src/collector/tweak_score_top_collector.rs b/src/collector/tweak_score_top_collector.rs
deleted file mode 100644
index e7e8d1547..000000000
--- a/src/collector/tweak_score_top_collector.rs
+++ /dev/null
@@ -1,124 +0,0 @@
-use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
-use crate::collector::{Collector, SegmentCollector};
-use crate::{DocAddress, DocId, Result, Score, SegmentReader};
-
-pub(crate) struct TweakedScoreTopCollector<TScoreTweaker, TScore = Score> {
-    score_tweaker: TScoreTweaker,
-    collector: TopCollector<TScore>,
-}
-
-impl<TScoreTweaker, TScore> TweakedScoreTopCollector<TScoreTweaker, TScore>
-where TScore: Clone + PartialOrd
-{
-    pub fn new(
-        score_tweaker: TScoreTweaker,
-        collector: TopCollector<TScore>,
-    ) -> TweakedScoreTopCollector<TScoreTweaker, TScore> {
-        TweakedScoreTopCollector {
-            score_tweaker,
-            collector,
-        }
-    }
-}
-
-/// A `ScoreSegmentTweaker` makes it possible to modify the default score
-/// for a given document belonging to a specific segment.
-///
-/// It is the segment local version of the [`ScoreTweaker`].
-pub trait ScoreSegmentTweaker<TScore>: 'static {
-    /// Tweak the given `score` for the document `doc`.
-    fn score(&mut self, doc: DocId, score: Score) -> TScore;
-}
-
-/// `ScoreTweaker` makes it possible to tweak the score
-/// emitted by the scorer into another one.
-///
-/// The `ScoreTweaker` itself does not make much of the computation itself.
-/// Instead, it helps constructing `Self::Child` instances that will compute
-/// the score at a segment scale.
-pub trait ScoreTweaker<TScore>: Sync {
-    /// Type of the associated [`ScoreSegmentTweaker`].
-    type Child: ScoreSegmentTweaker<TScore>;
-
-    /// Builds a child tweaker for a specific segment. The child scorer is associated with
-    /// a specific segment.
-    fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
-}
-
-impl<TScoreTweaker, TScore> Collector for TweakedScoreTopCollector<TScoreTweaker, TScore>
-where
-    TScoreTweaker: ScoreTweaker<TScore> + Send + Sync,
-    TScore: 'static + PartialOrd + Clone + Send + Sync,
-{
-    type Fruit = Vec<(TScore, DocAddress)>;
-
-    type Child = TopTweakedScoreSegmentCollector<TScoreTweaker::Child, TScore>;
-
-    fn for_segment(
-        &self,
-        segment_local_id: u32,
-        segment_reader: &SegmentReader,
-    ) -> Result<Self::Child> {
-        let segment_scorer = self.score_tweaker.segment_tweaker(segment_reader)?;
-        let segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
-        Ok(TopTweakedScoreSegmentCollector {
-            segment_collector,
-            segment_scorer,
-        })
-    }
-
-    fn requires_scoring(&self) -> bool {
-        true
-    }
-
-    fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
-        self.collector.merge_fruits(segment_fruits)
-    }
-}
-
-pub struct TopTweakedScoreSegmentCollector<TSegmentScoreTweaker, TScore>
-where
-    TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
-    TSegmentScoreTweaker: ScoreSegmentTweaker<TScore>,
-{
-    segment_collector: TopSegmentCollector<TScore>,
-    segment_scorer: TSegmentScoreTweaker,
-}
-
-impl<TSegmentScoreTweaker, TScore> SegmentCollector
-    for TopTweakedScoreSegmentCollector<TSegmentScoreTweaker, TScore>
-where
-    TScore: 'static + PartialOrd + Clone + Send + Sync,
-    TSegmentScoreTweaker: 'static + ScoreSegmentTweaker<TScore>,
-{
-    type Fruit = Vec<(TScore, DocAddress)>;
-
-    fn collect(&mut self, doc: DocId, score: Score) {
-        let score = self.segment_scorer.score(doc, score);
-        self.segment_collector.collect(doc, score);
-    }
-
-    fn harvest(self) -> Vec<(TScore, DocAddress)> {
-        self.segment_collector.harvest()
-    }
-}
-
-impl<F, TScore, TSegmentScoreTweaker> ScoreTweaker<TScore> for F
-where
-    F: 'static + Send + Sync + Fn(&SegmentReader) -> TSegmentScoreTweaker,
-    TSegmentScoreTweaker: ScoreSegmentTweaker<TScore>,
-{
-    type Child = TSegmentScoreTweaker;
-
-    fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
-        Ok((self)(segment_reader))
-    }
-}
-
-impl<F, TScore> ScoreSegmentTweaker<TScore> for F
-where F: 'static + FnMut(DocId, Score) -> TScore
-{
-    fn score(&mut self, doc: DocId, score: Score) -> TScore {
-        (self)(doc, score)
-    }
-}
diff --git a/src/compat_tests.rs b/src/compat_tests.rs
index ac1d00a45..e9af8803c 100644
--- a/src/compat_tests.rs
+++ b/src/compat_tests.rs
@@ -69,7 +69,7 @@ fn assert_date_time_precision(index: &Index, doc_store_precision: DateTimePrecis
         .parse_query("dateformat")
         .expect("Failed to parse query");
     let top_docs = searcher
-        .search(&query, &TopDocs::with_limit(1))
+        .search(&query, &TopDocs::with_limit(1).order_by_score())
         .expect("Search failed");
 
     assert_eq!(top_docs.len(), 1, "Expected 1 search result");
diff --git a/src/core/searcher.rs b/src/core/searcher.rs
index 51db7311f..9603d0f4f 100644
--- a/src/core/searcher.rs
+++ b/src/core/searcher.rs
@@ -225,6 +225,7 @@ impl Searcher {
         enabled_scoring: EnableScoring,
     ) -> crate::Result<C::Fruit> {
         let weight = query.weight(enabled_scoring)?;
+        collector.check_schema(self.schema())?;
         let segment_readers = self.segment_readers();
         let fruits = executor.map(
             |(segment_ord, segment_reader)| {
diff --git a/src/index/index_meta.rs b/src/index/index_meta.rs
index 0962bd9bc..86eaa35d6 100644
--- a/src/index/index_meta.rs
+++ b/src/index/index_meta.rs
@@ -276,13 +276,14 @@ impl Default for IndexSettings {
 }
 
 /// The order to sort by
-#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
+#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
 pub enum Order {
     /// Ascending Order
     Asc,
     /// Descending Order
     Desc,
 }
+
 impl Order {
     /// Returns `true` if the order is ascending.
     pub fn is_asc(&self) -> bool {
diff --git a/src/indexer/index_writer.rs b/src/indexer/index_writer.rs
index 7f993027e..1ba92d6de 100644
--- a/src/indexer/index_writer.rs
+++ b/src/indexer/index_writer.rs
@@ -513,7 +513,7 @@ impl IndexWriter {
     /// let searcher = index.reader()?.searcher();
     /// let query_parser = QueryParser::for_index(&index, vec![title]);
     /// let query_promo = query_parser.parse_query("Prometheus")?;
-    /// let top_docs_promo = searcher.search(&query_promo, &TopDocs::with_limit(1))?;
+    /// let top_docs_promo = searcher.search(&query_promo, &TopDocs::with_limit(1).order_by_score())?;
     ///
     /// assert!(top_docs_promo.is_empty());
     /// Ok(())
@@ -946,11 +946,11 @@
         let searcher = reader.searcher();
 
         let a_docs = searcher
-            .search(&a_query, &TopDocs::with_limit(1))
+            .search(&a_query, &TopDocs::with_limit(1).order_by_score())
             .expect("search for a failed");
 
         let b_docs = searcher
-            .search(&b_query, &TopDocs::with_limit(1))
+            .search(&b_query, &TopDocs::with_limit(1).order_by_score())
             .expect("search for b failed");
 
         assert_eq!(a_docs.len(), 1);
@@ -2014,8 +2014,9 @@
             let query = QueryParser::for_index(&index, vec![field])
                 .parse_query(term)
                 .unwrap();
-            let top_docs: Vec<(f32, DocAddress)> =
-                searcher.search(&query, &TopDocs::with_limit(1000)).unwrap();
+            let top_docs: Vec<(f32, DocAddress)> = searcher
+                .search(&query, &TopDocs::with_limit(1000).order_by_score())
+                .unwrap();
             top_docs.iter().map(|el| el.1).collect::<Vec<_>>()
         };
 
@@ -2449,8 +2450,9 @@
             Term::from_field_u64(id_field, existing_id),
             IndexRecordOption::Basic,
         );
-        let top_docs: Vec<(f32, DocAddress)> =
-            searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
+        let top_docs: Vec<(f32, DocAddress)> = searcher
+            .search(&query, &TopDocs::with_limit(10).order_by_score())
+            .unwrap();
         assert_eq!(top_docs.len(), 1);
 
         // Was failing
@@ -2491,8 +2493,9 @@
             Term::from_field_i64(id_field, 10i64),
             IndexRecordOption::Basic,
         );
-        let top_docs: Vec<(f32, DocAddress)> =
-            searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
+        let top_docs: Vec<(f32, DocAddress)> = searcher
+            .search(&query, &TopDocs::with_limit(10).order_by_score())
+            .unwrap();
         assert_eq!(top_docs.len(), 1);
 
         // Fails
@@ -2500,8 +2503,9 @@
             Term::from_field_i64(id_field, 30i64),
             IndexRecordOption::Basic,
         );
-        let top_docs: Vec<(f32, DocAddress)> =
-            searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
+        let top_docs: Vec<(f32, DocAddress)> = searcher
+            .search(&query, &TopDocs::with_limit(10).order_by_score())
+            .unwrap();
         assert_eq!(top_docs.len(), 1);
 
         // Fails
diff --git a/src/indexer/merge_index_test.rs b/src/indexer/merge_index_test.rs
index 8b8dec3ae..43f80a9d0 100644
--- a/src/indexer/merge_index_test.rs
+++ b/src/indexer/merge_index_test.rs
@@ -104,8 +104,9 @@ mod tests {
         let query = QueryParser::for_index(&index, vec![my_text_field])
             .parse_query(term)
             .unwrap();
-        let top_docs: Vec<(f32, DocAddress)> =
-            searcher.search(&query, &TopDocs::with_limit(3)).unwrap();
+        let top_docs: Vec<(f32, DocAddress)> = searcher
+            .search(&query, &TopDocs::with_limit(3).order_by_score())
+            .unwrap();
 
         top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
     };
diff --git a/src/indexer/mod.rs b/src/indexer/mod.rs
index 687a6b119..2d86aa461 100644
--- a/src/indexer/mod.rs
+++ b/src/indexer/mod.rs
@@ -589,7 +589,9 @@ mod tests_mmap {
         };
         let query_str = &format!("{}:{}", indexed_field.field_name, val);
         let query = query_parser.parse_query(query_str).unwrap();
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2)).unwrap();
+        let count_docs = searcher
+            .search(&*query, &TopDocs::with_limit(2).order_by_score())
+            .unwrap();
         if indexed_field.field_name.contains("empty") || indexed_field.typ == Type::Json {
             assert_eq!(count_docs.len(), 0);
         } else {
@@ -661,7 +663,9 @@ mod tests_mmap {
         for (indexed_field, val) in fields_and_vals.iter() {
             let query_str = &format!("{indexed_field}:{val}");
             let query = query_parser.parse_query(query_str).unwrap();
-            let count_docs = searcher.search(&*query, &TopDocs::with_limit(2)).unwrap();
+            let count_docs = searcher
+                .search(&*query, &TopDocs::with_limit(2).order_by_score())
+                .unwrap();
             assert!(!count_docs.is_empty(), "{indexed_field}:{val}");
         }
         // Test if field name can be used for aggregation
diff --git a/src/indexer/segment_updater.rs b/src/indexer/segment_updater.rs
index b72667ded..bc941d1ed 100644
--- a/src/indexer/segment_updater.rs
+++ b/src/indexer/segment_updater.rs
@@ -1052,8 +1052,9 @@ mod tests {
         let query = QueryParser::for_index(&index, vec![text_field])
             .parse_query(term)
             .unwrap();
-        let top_docs: Vec<(f32, DocAddress)> =
-            searcher.search(&query, &TopDocs::with_limit(3)).unwrap();
+        let top_docs: Vec<(f32, DocAddress)> = searcher
+            .search(&query, &TopDocs::with_limit(3).order_by_score())
+            .unwrap();
 
         top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
     };
diff --git a/src/indexer/segment_writer.rs b/src/indexer/segment_writer.rs
index cfeea1177..30338187c 100644
--- a/src/indexer/segment_writer.rs
+++ b/src/indexer/segment_writer.rs
@@ -520,7 +520,7 @@ mod tests {
             .reader()
             .unwrap()
             .searcher()
-            .search(&text_query, &TopDocs::with_limit(4))
+            .search(&text_query, &TopDocs::with_limit(4).order_by_score())
             .unwrap();
         assert_eq!(score_docs.len(), 1);
 
@@ -529,7 +529,7 @@
             .reader()
             .unwrap()
             .searcher()
-            .search(&text_query, &TopDocs::with_limit(4))
+            .search(&text_query, &TopDocs::with_limit(4).order_by_score())
             .unwrap();
         assert_eq!(score_docs.len(), 2);
     }
@@ -562,7 +562,7 @@
             .reader()
             .unwrap()
             .searcher()
-            .search(&text_query, &TopDocs::with_limit(4))
+            .search(&text_query, &TopDocs::with_limit(4).order_by_score())
             .unwrap();
         assert_eq!(score_docs.len(), 1);
     };
diff --git a/src/lib.rs b/src/lib.rs
index 8077565c7..1027f4f46 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -85,7 +85,7 @@
 //! // Perform search.
 //! // `topdocs` contains the 10 most relevant doc ids, sorted by decreasing scores...
 //! let top_docs: Vec<(Score, DocAddress)> =
-//!     searcher.search(&query, &TopDocs::with_limit(10))?;
+//!     searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
 //!
 //! for (_score, doc_address) in top_docs {
 //!     // Retrieve the actual content of documents given their `doc_address`.
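The `src/lib.rs` hunk above is representative of the whole migration: every call site that previously relied on `TopDocs::with_limit(n)` implying relevance order now has to state its ordering. A minimal sketch of the two main caller-side patterns, assuming only the signatures visible in this diff (the `"size"` field name is a placeholder, not part of the patch):

    use tantivy::collector::TopDocs;
    use tantivy::query::Query;
    use tantivy::{DocAddress, Order, Score, Searcher};

    fn migrated(searcher: &Searcher, query: &dyn Query) -> tantivy::Result<()> {
        // Relevance ordering is now explicit.
        let _by_score: Vec<(Score, DocAddress)> =
            searcher.search(query, &TopDocs::with_limit(10).order_by_score())?;

        // Fast-field ordering now yields `Option<T>`: documents that have no
        // value in the fast field come back as `None` instead of a sentinel
        // such as `u64::MAX`.
        let _by_size: Vec<(Option<u64>, DocAddress)> = searcher.search(
            query,
            &TopDocs::with_limit(10).order_by_fast_field("size", Order::Asc),
        )?;
        Ok(())
    }

Note that `check_schema` (wired into `Searcher::search_with_executor` above) surfaces a missing fast field as a `SchemaError` before any segment is visited, rather than failing per segment.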
diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs
index cacd30e57..0ddc5a26c 100644
--- a/src/query/boolean_query/mod.rs
+++ b/src/query/boolean_query/mod.rs
@@ -182,7 +182,7 @@ mod tests {
         let matching_topdocs = |query: &dyn Query| {
             reader
                 .searcher()
-                .search(query, &TopDocs::with_limit(3))
+                .search(query, &TopDocs::with_limit(3).order_by_score())
                 .unwrap()
         };
 
diff --git a/src/query/disjunction_max_query.rs b/src/query/disjunction_max_query.rs
index e7d420316..cb62b6fd3 100644
--- a/src/query/disjunction_max_query.rs
+++ b/src/query/disjunction_max_query.rs
@@ -53,7 +53,7 @@ use crate::{Score, Term};
 /// // TermQuery "diary" and "girl" should be present and only one should be counted in the score
 /// let queries1 = vec![diary_term_query.box_clone(), girl_term_query.box_clone()];
 /// let diary_and_girl = DisjunctionMaxQuery::new(queries1);
-/// let documents = searcher.search(&diary_and_girl, &TopDocs::with_limit(3))?;
+/// let documents = searcher.search(&diary_and_girl, &TopDocs::with_limit(3).order_by_score())?;
 /// assert_eq!(documents[0].0, documents[1].0);
 /// assert_eq!(documents[1].0, documents[2].0);
 ///
@@ -62,7 +62,7 @@
 /// let queries2 = vec![diary_term_query.box_clone(), girl_term_query.box_clone()];
 /// let tie_breaker = 0.7;
 /// let diary_and_girl_with_tie_breaker = DisjunctionMaxQuery::with_tie_breaker(queries2, tie_breaker);
-/// let documents = searcher.search(&diary_and_girl_with_tie_breaker, &TopDocs::with_limit(3))?;
+/// let documents = searcher.search(&diary_and_girl_with_tie_breaker, &TopDocs::with_limit(3).order_by_score())?;
 /// assert_eq!(documents[1].0, documents[2].0);
 /// // For this test all terms bring the same score. So we can do easy math and assume that
 /// // `DisjunctionMaxQuery` with tie breakers score should be equal
diff --git a/src/query/fuzzy_query.rs b/src/query/fuzzy_query.rs
index 143eed1c7..a0634b96b 100644
--- a/src/query/fuzzy_query.rs
+++ b/src/query/fuzzy_query.rs
@@ -67,7 +67,7 @@ impl Automaton for DfaWrapper {
 /// {
 ///     let term = Term::from_field_text(title, "Diary");
 ///     let query = FuzzyTermQuery::new(term, 1, true);
-///     let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count)).unwrap();
+///     let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count)).unwrap();
 ///     assert_eq!(count, 2);
 ///     assert_eq!(top_docs.len(), 2);
 /// }
@@ -241,7 +241,8 @@ mod test {
         {
             let term = get_json_path_term("attributes.aa:japan")?;
             let fuzzy_query = FuzzyTermQuery::new(term, 2, true);
-            let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+            let top_docs =
+                searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
             assert_eq!(top_docs.len(), 1, "Expected only 1 document");
             assert_eq!(top_docs[0].1.doc_id, 1, "Expected the second document");
         }
@@ -252,7 +253,8 @@
 
         let term = get_json_path_term("attributes.a:japon")?;
         let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
-        let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+        let top_docs =
+            searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(top_docs.len(), 1, "Expected only 1 document");
         assert_eq!(top_docs[0].1.doc_id, 0, "Expected the first document");
     }
@@ -262,7 +264,8 @@
 
         let term = get_json_path_term("attributes.a:jap")?;
         let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
-        let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+        let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?; assert_eq!(top_docs.len(), 0, "Expected no document"); } @@ -292,7 +295,8 @@ mod test { { let term = Term::from_field_text(country_field, "japon"); let fuzzy_query = FuzzyTermQuery::new(term, 1, true); - let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?; + let top_docs = + searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?; assert_eq!(top_docs.len(), 1, "Expected only 1 document"); let (score, _) = top_docs[0]; assert_nearly_equals!(1.0, score); @@ -303,7 +307,8 @@ mod test { let term = Term::from_field_text(country_field, "jap"); let fuzzy_query = FuzzyTermQuery::new(term, 1, true); - let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?; + let top_docs = + searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?; assert_eq!(top_docs.len(), 0, "Expected no document"); } @@ -311,7 +316,8 @@ mod test { { let term = Term::from_field_text(country_field, "jap"); let fuzzy_query = FuzzyTermQuery::new_prefix(term, 1, true); - let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?; + let top_docs = + searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?; assert_eq!(top_docs.len(), 1, "Expected only 1 document"); let (score, _) = top_docs[0]; assert_nearly_equals!(1.0, score); diff --git a/src/query/more_like_this/query.rs b/src/query/more_like_this/query.rs index dd3db39da..e48e61533 100644 --- a/src/query/more_like_this/query.rs +++ b/src/query/more_like_this/query.rs @@ -267,7 +267,7 @@ mod tests { .with_boost_factor(1.0) .with_stop_words(vec!["old".to_string()]) .with_document(DocAddress::new(0, 0)); - let top_docs = searcher.search(&query, &TopDocs::with_limit(5))?; + let top_docs = searcher.search(&query, &TopDocs::with_limit(5).order_by_score())?; let mut doc_ids: Vec<_> = top_docs.iter().map(|item| item.1.doc_id).collect(); doc_ids.sort_unstable(); @@ -283,7 +283,7 @@ mod tests { .with_max_word_length(5) .with_boost_factor(1.0) .with_document(DocAddress::new(0, 4)); - let top_docs = searcher.search(&query, &TopDocs::with_limit(5))?; + let top_docs = searcher.search(&query, &TopDocs::with_limit(5).order_by_score())?; let mut doc_ids: Vec<_> = top_docs.iter().map(|item| item.1.doc_id).collect(); doc_ids.sort_unstable(); diff --git a/src/query/range_query/range_query.rs b/src/query/range_query/range_query.rs index 5fe7f03ec..1893a06a5 100644 --- a/src/query/range_query/range_query.rs +++ b/src/query/range_query/range_query.rs @@ -496,7 +496,7 @@ mod tests { let searcher = reader.searcher(); let query_parser = QueryParser::for_index(&index, vec![title]); let query = query_parser.parse_query("hemoglobin AND year:[1970 TO 1990]")?; - let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?; + let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?; assert_eq!(top_docs.len(), 1); Ok(()) } @@ -550,7 +550,7 @@ mod tests { let get_num_hits = |query| { let (_top_docs, count) = searcher - .search(&query, &(TopDocs::with_limit(10), Count)) + .search(&query, &(TopDocs::with_limit(10).order_by_score(), Count)) .unwrap(); count }; diff --git a/src/query/range_query/range_query_fastfield.rs b/src/query/range_query/range_query_fastfield.rs index 0246ee526..54cf0cad5 100644 --- a/src/query/range_query/range_query_fastfield.rs +++ b/src/query/range_query/range_query_fastfield.rs @@ -527,7 +527,9 @@ mod tests { let test_query = |query, num_hits| { let query = 
query_parser.parse_query(query).unwrap(); - let top_docs = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); + let top_docs = searcher + .search(&query, &TopDocs::with_limit(10).order_by_score()) + .unwrap(); assert_eq!(top_docs.len(), num_hits); }; @@ -613,7 +615,9 @@ mod tests { let query_parser = QueryParser::for_index(&index, vec![date_field]); let test_query = |query, num_hits| { let query = query_parser.parse_query(query).unwrap(); - let top_docs = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); + let top_docs = searcher + .search(&query, &TopDocs::with_limit(10).order_by_score()) + .unwrap(); assert_eq!(top_docs.len(), num_hits); }; @@ -993,7 +997,9 @@ mod tests { let query_parser = QueryParser::for_index(&index, vec![json_field]); let test_query = |query, num_hits| { let query = query_parser.parse_query(query).unwrap(); - let top_docs = searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); + let top_docs = searcher + .search(&query, &TopDocs::with_limit(10).order_by_score()) + .unwrap(); assert_eq!(top_docs.len(), num_hits); }; diff --git a/src/query/regex_query.rs b/src/query/regex_query.rs index cc5701744..1ee077c08 100644 --- a/src/query/regex_query.rs +++ b/src/query/regex_query.rs @@ -125,14 +125,20 @@ mod test { let searcher = reader.searcher(); { let scored_docs = searcher - .search(&query_matching_one, &TopDocs::with_limit(2)) + .search( + &query_matching_one, + &TopDocs::with_limit(2).order_by_score(), + ) .unwrap(); assert_eq!(scored_docs.len(), 1, "Expected only 1 document"); let (score, _) = scored_docs[0]; assert_nearly_equals!(1.0, score); } let top_docs = searcher - .search(&query_matching_zero, &TopDocs::with_limit(2)) + .search( + &query_matching_zero, + &TopDocs::with_limit(2).order_by_score(), + ) .unwrap(); assert!(top_docs.is_empty(), "Expected ZERO document"); } diff --git a/src/query/set_query.rs b/src/query/set_query.rs index 5fceac50f..12e1b2818 100644 --- a/src/query/set_query.rs +++ b/src/query/set_query.rs @@ -153,7 +153,8 @@ mod tests { let terms = vec![Term::from_field_text(field1, "doc1")]; let term_set_query = TermSetQuery::new(terms); - let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(2))?; + let top_docs = + searcher.search(&term_set_query, &TopDocs::with_limit(2).order_by_score())?; assert_eq!(top_docs.len(), 1, "Expected 1 document"); let (score, _) = top_docs[0]; assert_nearly_equals!(1.0, score); @@ -164,7 +165,8 @@ mod tests { let terms = vec![Term::from_field_text(field1, "doc4")]; let term_set_query = TermSetQuery::new(terms); - let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(1))?; + let top_docs = + searcher.search(&term_set_query, &TopDocs::with_limit(1).order_by_score())?; assert!(top_docs.is_empty(), "Expected 0 document"); } @@ -176,7 +178,8 @@ mod tests { ]; let term_set_query = TermSetQuery::new(terms); - let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(2))?; + let top_docs = + searcher.search(&term_set_query, &TopDocs::with_limit(2).order_by_score())?; assert_eq!(top_docs.len(), 2, "Expected 2 documents"); for (score, _) in top_docs { assert_nearly_equals!(1.0, score); @@ -192,7 +195,8 @@ mod tests { ]; let term_set_query = TermSetQuery::new(terms); - let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(3))?; + let top_docs = + searcher.search(&term_set_query, &TopDocs::with_limit(3).order_by_score())?; assert_eq!(top_docs.len(), 2, "Expected 2 document"); for (score, _) in top_docs { @@ -205,13 +209,15 @@ mod tests { let terms = 
vec![Term::from_field_text(field1, "doc3")]; let term_set_query = TermSetQuery::new(terms); - let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(3))?; + let top_docs = + searcher.search(&term_set_query, &TopDocs::with_limit(3).order_by_score())?; assert_eq!(top_docs.len(), 1, "Expected 1 document"); let terms = vec![Term::from_field_text(field2, "doc3")]; let term_set_query = TermSetQuery::new(terms); - let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(3))?; + let top_docs = + searcher.search(&term_set_query, &TopDocs::with_limit(3).order_by_score())?; assert_eq!(top_docs.len(), 1, "Expected 1 document"); let terms = vec![ @@ -220,7 +226,8 @@ mod tests { ]; let term_set_query = TermSetQuery::new(terms); - let top_docs = searcher.search(&term_set_query, &TopDocs::with_limit(3))?; + let top_docs = + searcher.search(&term_set_query, &TopDocs::with_limit(3).order_by_score())?; assert_eq!(top_docs.len(), 2, "Expected 2 document"); } @@ -249,7 +256,7 @@ mod tests { let searcher = reader.searcher(); let query_parser = QueryParser::for_index(&index, vec![]); let query = query_parser.parse_query("field: IN [val1 val2]")?; - let top_docs = searcher.search(&query, &TopDocs::with_limit(3))?; + let top_docs = searcher.search(&query, &TopDocs::with_limit(3).order_by_score())?; assert_eq!(top_docs.len(), 2); Ok(()) } diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs index 1c1fa8389..0811725be 100644 --- a/src/query/term_query/mod.rs +++ b/src/query/term_query/mod.rs @@ -100,7 +100,7 @@ mod tests { { let term = Term::from_field_text(left_field, "left2"); let term_query = TermQuery::new(term, IndexRecordOption::WithFreqs); - let topdocs = searcher.search(&term_query, &TopDocs::with_limit(2))?; + let topdocs = searcher.search(&term_query, &TopDocs::with_limit(2).order_by_score())?; assert_eq!(topdocs.len(), 1); let (score, _) = topdocs[0]; assert_nearly_equals!(0.77802235, score); @@ -108,7 +108,8 @@ mod tests { { let term = Term::from_field_text(left_field, "left1"); let term_query = TermQuery::new(term, IndexRecordOption::WithFreqs); - let top_docs = searcher.search(&term_query, &TopDocs::with_limit(2))?; + let top_docs = + searcher.search(&term_query, &TopDocs::with_limit(2).order_by_score())?; assert_eq!(top_docs.len(), 2); let (score1, _) = top_docs[0]; assert_nearly_equals!(0.27101856, score1); @@ -118,7 +119,7 @@ mod tests { { let query_parser = QueryParser::for_index(&index, Vec::new()); let query = query_parser.parse_query("left:left2 left:left1")?; - let top_docs = searcher.search(&query, &TopDocs::with_limit(2))?; + let top_docs = searcher.search(&query, &TopDocs::with_limit(2).order_by_score())?; assert_eq!(top_docs.len(), 2); let (score1, _) = top_docs[0]; assert_nearly_equals!(0.9153879, score1); @@ -438,7 +439,7 @@ mod tests { // Using TopDocs requires scoring; since the field is not indexed, // TermQuery cannot score and should return a SchemaError. 
- let res = searcher.search(&tq, &TopDocs::with_limit(1)); + let res = searcher.search(&tq, &TopDocs::with_limit(1).order_by_score()); assert!(matches!(res, Err(crate::TantivyError::SchemaError(_)))); Ok(()) diff --git a/src/query/term_query/term_query.rs b/src/query/term_query/term_query.rs index dae6d7fe7..4d6125672 100644 --- a/src/query/term_query/term_query.rs +++ b/src/query/term_query/term_query.rs @@ -50,7 +50,7 @@ use crate::Term; /// Term::from_field_text(title, "diary"), /// IndexRecordOption::Basic, /// ); -/// let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?; +/// let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?; /// assert_eq!(count, 2); /// Ok(()) /// # } @@ -190,7 +190,7 @@ mod tests { let assert_single_hit = |query| { let (_top_docs, count) = searcher - .search(&query, &(TopDocs::with_limit(2), Count)) + .search(&query, &(TopDocs::with_limit(2).order_by_score(), Count)) .unwrap(); assert_eq!(count, 1); };
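Finally, to make the amortized strategy of the reworked `TopNComputer` concrete: the buffer holds up to `2 * top_n` entries; when it fills, a linear-time selection moves the best `top_n` entries to the front, the rest is truncated, and the `(top_n + 1)`-th best key becomes a threshold that rejects most subsequent pushes with a single comparison. A self-contained sketch of that idea (illustrative only; `MiniTopN` and its fixed `u64` keys are stand-ins, not the tantivy API):

    /// Keeps the `n` entries with the largest keys, in O(1) amortized time per push.
    struct MiniTopN {
        buffer: Vec<(u64, u32)>, // (sort_key, doc); larger key = better
        n: usize,
        threshold: Option<u64>,
    }

    impl MiniTopN {
        fn new(n: usize) -> Self {
            MiniTopN {
                buffer: Vec::with_capacity(n.max(1) * 2),
                n,
                threshold: None,
            }
        }

        fn push(&mut self, key: u64, doc: u32) {
            // Cheap rejection path, analogous to the `Ordering::Less` check above.
            if let Some(threshold) = self.threshold {
                if key < threshold {
                    return;
                }
            }
            if self.buffer.len() == self.buffer.capacity() {
                // Partition so the n largest keys occupy buffer[..n], in O(len).
                self.buffer.select_nth_unstable_by(self.n, |a, b| b.0.cmp(&a.0));
                // The (n + 1)-th best key seen so far becomes the new threshold.
                self.threshold = Some(self.buffer[self.n].0);
                self.buffer.truncate(self.n);
            }
            self.buffer.push((key, doc));
        }

        fn into_sorted(mut self) -> Vec<(u64, u32)> {
            // Descending by key, ascending by doc id on ties, mirroring the
            // tie-breaking used by `into_sorted_vec` above.
            self.buffer
                .sort_unstable_by(|a, b| b.0.cmp(&a.0).then(a.1.cmp(&b.1)));
            self.buffer.truncate(self.n);
            self.buffer
        }
    }

    fn main() {
        let mut top = MiniTopN::new(2);
        for (key, doc) in [(3, 0), (1, 1), (7, 2), (5, 3)] {
            top.push(key, doc);
        }
        assert_eq!(top.into_sorted(), vec![(7, 2), (5, 3)]);
    }

Keeping the buffer at twice `top_n` is what makes the cost amortized: each O(len) selection pass is paid for by the `top_n` cheap appends that preceded it, and the threshold lets most non-competitive documents be discarded without ever touching the buffer.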