diff --git a/src/collector/tests.rs b/src/collector/tests.rs index f348b4c15..53c32e6a2 100644 --- a/src/collector/tests.rs +++ b/src/collector/tests.rs @@ -8,13 +8,23 @@ use crate::DocId; use crate::Score; use crate::SegmentLocalId; +pub const TEST_COLLECTOR_WITH_SCORE: TestCollector = TestCollector { + compute_score: true, +}; + +pub const TEST_COLLECTOR_WITHOUT_SCORE: TestCollector = TestCollector { + compute_score: false, +}; + /// Stores all of the doc ids. /// This collector is only used for tests. /// It is unusable in pr /// /// actise, as it does not store /// the segment ordinals -pub struct TestCollector; +pub struct TestCollector { + pub compute_score: bool, +} pub struct TestSegmentCollector { segment_id: SegmentLocalId, @@ -32,7 +42,6 @@ impl TestFruit { pub fn docs(&self) -> &[DocAddress] { &self.docs[..] } - pub fn scores(&self) -> &[Score] { &self.scores[..] } @@ -54,7 +63,7 @@ impl Collector for TestCollector { } fn requires_scoring(&self) -> bool { - true + self.compute_score } fn merge_fruits(&self, mut children: Vec) -> Result { diff --git a/src/indexer/merger.rs b/src/indexer/merger.rs index 4d15d7ffa..d01f351ae 100644 --- a/src/indexer/merger.rs +++ b/src/indexer/merger.rs @@ -692,7 +692,7 @@ impl SerializableSegment for IndexMerger { #[cfg(test)] mod tests { - use crate::collector::tests::TestCollector; + use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE; use crate::collector::tests::{BytesFastFieldTestCollector, FastFieldTestCollector}; use crate::collector::{Count, FacetCollector}; use crate::core::Index; @@ -807,7 +807,7 @@ mod tests { let searcher = reader.searcher(); let get_doc_ids = |terms: Vec| { let query = BooleanQuery::new_multiterms_query(terms); - let top_docs = searcher.search(&query, &TestCollector).unwrap(); + let top_docs = searcher.search(&query, &TEST_COLLECTOR_WITH_SCORE).unwrap(); top_docs.docs().to_vec() }; { diff --git a/src/lib.rs b/src/lib.rs index 85d2541a1..be29b623a 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ 
-250,7 +250,7 @@ pub struct DocAddress(pub SegmentLocalId, pub DocId); #[cfg(test)] mod tests { - use crate::collector::tests::TestCollector; + use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE; use crate::core::SegmentReader; use crate::docset::DocSet; use crate::query::BooleanQuery; @@ -737,7 +737,7 @@ mod tests { let searcher = reader.searcher(); let get_doc_ids = |terms: Vec| { let query = BooleanQuery::new_multiterms_query(terms); - let topdocs = searcher.search(&query, &TestCollector).unwrap(); + let topdocs = searcher.search(&query, &TEST_COLLECTOR_WITH_SCORE).unwrap(); topdocs.docs().to_vec() }; assert_eq!( diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 04f248322..47a3b0aab 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -7,7 +7,7 @@ pub use self::boolean_query::BooleanQuery; mod tests { use super::*; - use crate::collector::tests::TestCollector; + use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE; use crate::query::score_combiner::SumWithCoordsCombiner; use crate::query::term_query::TermScorer; use crate::query::Intersection; @@ -134,7 +134,7 @@ mod tests { let matching_docs = |boolean_query: &dyn Query| { reader .searcher() - .search(boolean_query, &TestCollector) + .search(boolean_query, &TEST_COLLECTOR_WITH_SCORE) .unwrap() .docs() .iter() @@ -195,7 +195,7 @@ mod tests { let score_docs = |boolean_query: &dyn Query| { let fruit = reader .searcher() - .search(boolean_query, &TestCollector) + .search(boolean_query, &TEST_COLLECTOR_WITH_SCORE) .unwrap(); fruit.scores().to_vec() }; diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index 306ad2730..d43b65f31 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -10,13 +10,13 @@ pub use self::phrase_weight::PhraseWeight; mod tests { use super::*; - use crate::collector::tests::TestCollector; + use crate::collector::tests::{TEST_COLLECTOR_WITHOUT_SCORE, 
TEST_COLLECTOR_WITH_SCORE}; use crate::core::Index; use crate::error::TantivyError; use crate::schema::{Schema, Term, TEXT}; use crate::tests::assert_nearly_equals; - use crate::DocAddress; use crate::DocId; + use crate::{DocAddress, DocSet}; fn create_index(texts: &[&'static str]) -> Index { let mut schema_builder = Schema::builder(); @@ -53,7 +53,7 @@ mod tests { .collect(); let phrase_query = PhraseQuery::new(terms); let test_fruits = searcher - .search(&phrase_query, &TestCollector) + .search(&phrase_query, &TEST_COLLECTOR_WITH_SCORE) .expect("search should succeed"); test_fruits .docs() @@ -68,6 +68,64 @@ mod tests { assert!(test_query(vec!["g", "a"]).is_empty()); } + #[test] + pub fn test_phrase_query_no_score() { + let index = create_index(&[ + "b b b d c g c", + "a b b d c g c", + "a b a b c", + "c a b a d ga a", + "a b c", + ]); + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader().unwrap().searcher(); + let test_query = |texts: Vec<&str>| { + let terms: Vec = texts + .iter() + .map(|text| Term::from_field_text(text_field, text)) + .collect(); + let phrase_query = PhraseQuery::new(terms); + let test_fruits = searcher + .search(&phrase_query, &TEST_COLLECTOR_WITHOUT_SCORE) + .expect("search should succeed"); + test_fruits + .docs() + .iter() + .map(|docaddr| docaddr.1) + .collect::>() + }; + assert_eq!(test_query(vec!["a", "b", "c"]), vec![2, 4]); + assert_eq!(test_query(vec!["a", "b"]), vec![1, 2, 3, 4]); + assert_eq!(test_query(vec!["b", "b"]), vec![0, 1]); + assert!(test_query(vec!["g", "ewrwer"]).is_empty()); + assert!(test_query(vec!["g", "a"]).is_empty()); + } + + #[test] + pub fn test_phrase_count() { + let index = create_index(&["a c", "a a b d a b c", " a b"]); + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader().unwrap().searcher(); + let phrase_query = PhraseQuery::new(vec![ + Term::from_field_text(text_field, "a"), + 
Term::from_field_text(text_field, "b"), + ]); + let phrase_weight = phrase_query.phrase_weight(&searcher, true).unwrap(); + let mut phrase_scorer = phrase_weight + .phrase_scorer(searcher.segment_reader(0u32)) + .unwrap() + .unwrap(); + assert!(phrase_scorer.advance()); + assert_eq!(phrase_scorer.doc(), 1); + assert_eq!(phrase_scorer.phrase_count(), 2); + assert!(phrase_scorer.advance()); + assert_eq!(phrase_scorer.doc(), 2); + assert_eq!(phrase_scorer.phrase_count(), 1); + assert!(!phrase_scorer.advance()); + } + #[test] pub fn test_phrase_query_no_positions() { let mut schema_builder = Schema::builder(); @@ -93,17 +151,20 @@ mod tests { Term::from_field_text(text_field, "a"), Term::from_field_text(text_field, "b"), ]); - if let TantivyError::SchemaError(ref msg) = searcher - .search(&phrase_query, &TestCollector) + match searcher + .search(&phrase_query, &TEST_COLLECTOR_WITH_SCORE) .map(|_| ()) .unwrap_err() { - assert_eq!( - "Applied phrase query on field \"text\", which does not have positions indexed", - msg.as_str() - ); - } else { - panic!("Should have returned an error"); + TantivyError::SchemaError(ref msg) => { + assert_eq!( + "Applied phrase query on field \"text\", which does not have positions indexed", + msg.as_str() + ); + } + _ => { + panic!("Should have returned an error"); + } } } @@ -120,7 +181,7 @@ mod tests { .collect(); let phrase_query = PhraseQuery::new(terms); searcher - .search(&phrase_query, &TestCollector) + .search(&phrase_query, &TEST_COLLECTOR_WITH_SCORE) .expect("search should succeed") .scores() .to_vec() @@ -152,7 +213,7 @@ mod tests { .collect(); let phrase_query = PhraseQuery::new(terms); searcher - .search(&phrase_query, &TestCollector) + .search(&phrase_query, &TEST_COLLECTOR_WITH_SCORE) .expect("search should succeed") .docs() .to_vec() @@ -180,7 +241,7 @@ mod tests { .collect(); let phrase_query = PhraseQuery::new_with_offset(terms); searcher - .search(&phrase_query, &TestCollector) + .search(&phrase_query, 
&TEST_COLLECTOR_WITH_SCORE) .expect("search should succeed") .docs() .iter() diff --git a/src/query/phrase_query/phrase_query.rs b/src/query/phrase_query/phrase_query.rs index a4e9d708b..8c1126e16 100644 --- a/src/query/phrase_query/phrase_query.rs +++ b/src/query/phrase_query/phrase_query.rs @@ -72,13 +72,16 @@ impl PhraseQuery { .map(|(_, term)| term.clone()) .collect::>() } -} -impl Query for PhraseQuery { - /// Create the weight associated to a query. + /// Returns the `PhraseWeight` for the given phrase query given a specific `searcher`. /// - /// See [`Weight`](./trait.Weight.html). - fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> Result> { + /// This function is the same as `.weight(...)` except it returns + /// a specialized type `PhraseWeight` instead of a Boxed trait. + pub(crate) fn phrase_weight( + &self, + searcher: &Searcher, + scoring_enabled: bool, + ) -> Result { let schema = searcher.schema(); let field_entry = schema.get_field_entry(self.field); let has_positions = field_entry @@ -95,9 +98,20 @@ impl Query for PhraseQuery { } let terms = self.phrase_terms(); let bm25_weight = BM25Weight::for_terms(searcher, &terms); + Ok(PhraseWeight::new( + self.phrase_terms.clone(), + bm25_weight, + scoring_enabled, + )) + } +} - let phrase_weight: PhraseWeight = - PhraseWeight::new(self.phrase_terms.clone(), bm25_weight, scoring_enabled); +impl Query for PhraseQuery { + /// Create the weight associated to a query. + /// + /// See [`Weight`](./trait.Weight.html). 
+ fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> Result> { + let phrase_weight = self.phrase_weight(searcher, scoring_enabled)?; Ok(Box::new(phrase_weight)) } diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index c85cfa8cc..e1d4fdccf 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -163,11 +163,9 @@ impl PhraseScorer { } fn phrase_exists(&mut self) -> bool { - { - self.intersection_docset - .docset_mut_specialized(0) - .positions(&mut self.left); - } + self.intersection_docset + .docset_mut_specialized(0) + .positions(&mut self.left); let mut intersection_len = self.left.len(); for i in 1..self.num_terms - 1 { { diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs index b60ce5a07..1fea04ec8 100644 --- a/src/query/phrase_query/phrase_weight.rs +++ b/src/query/phrase_query/phrase_weight.rs @@ -37,7 +37,7 @@ impl PhraseWeight { reader.get_fieldnorms_reader(field) } - fn phrase_scorer( + pub fn phrase_scorer( &self, reader: &SegmentReader, ) -> Result>> {