diff --git a/src/core/searcher.rs b/src/core/searcher.rs index 6d8d14392..ab23a73d0 100644 --- a/src/core/searcher.rs +++ b/src/core/searcher.rs @@ -198,11 +198,10 @@ impl Searcher { collector: &C, executor: &Executor, ) -> crate::Result { - let scoring_enabled = collector.requires_scoring(); - let enabled_scoring = if scoring_enabled { - EnableScoring::Enabled(self) + let enabled_scoring = if collector.requires_scoring() { + EnableScoring::enabled_from_searcher(self) } else { - EnableScoring::Disabled(self.schema()) + EnableScoring::disabled_from_searcher(self) }; let weight = query.weight(enabled_scoring)?; let segment_readers = self.segment_readers(); diff --git a/src/fastfield/bytes/mod.rs b/src/fastfield/bytes/mod.rs index d120431e1..a476639bc 100644 --- a/src/fastfield/bytes/mod.rs +++ b/src/fastfield/bytes/mod.rs @@ -96,7 +96,7 @@ mod tests { let term = Term::from_field_bytes(field, b"lucene".as_ref()); let term_query = TermQuery::new(term, IndexRecordOption::Basic); let term_weight_err = - term_query.specialized_weight(EnableScoring::Disabled(searcher.schema())); + term_query.specialized_weight(EnableScoring::disabled_from_schema(searcher.schema())); assert!(matches!( term_weight_err, Err(crate::TantivyError::SchemaError(_)) diff --git a/src/indexer/index_writer.rs b/src/indexer/index_writer.rs index 229595948..4ba6441d0 100644 --- a/src/indexer/index_writer.rs +++ b/src/indexer/index_writer.rs @@ -678,7 +678,7 @@ impl IndexWriter { /// only after calling `commit()`. #[doc(hidden)] pub fn delete_query(&self, query: Box) -> crate::Result { - let weight = query.weight(EnableScoring::Disabled(&self.index.schema()))?; + let weight = query.weight(EnableScoring::disabled_from_schema(&self.index.schema()))?; let opstamp = self.stamper.stamp(); let delete_operation = DeleteOperation { opstamp, @@ -759,7 +759,8 @@ impl IndexWriter { match user_op { UserOperation::Delete(term) => { let query = TermQuery::new(term, IndexRecordOption::Basic); - let weight = query.weight(EnableScoring::Disabled(&self.index.schema()))?; + let weight = + query.weight(EnableScoring::disabled_from_schema(&self.index.schema()))?; let delete_operation = DeleteOperation { opstamp, target: weight, diff --git a/src/query/all_query.rs b/src/query/all_query.rs index 669649454..31281ba05 100644 --- a/src/query/all_query.rs +++ b/src/query/all_query.rs @@ -95,7 +95,7 @@ mod tests { let index = create_test_index()?; let reader = index.reader()?; let searcher = reader.searcher(); - let weight = AllQuery.weight(EnableScoring::Disabled(&index.schema()))?; + let weight = AllQuery.weight(EnableScoring::disabled_from_schema(&index.schema()))?; { let reader = searcher.segment_reader(0); let mut scorer = weight.scorer(reader, 1.0)?; @@ -118,7 +118,7 @@ mod tests { let index = create_test_index()?; let reader = index.reader()?; let searcher = reader.searcher(); - let weight = AllQuery.weight(EnableScoring::Disabled(searcher.schema()))?; + let weight = AllQuery.weight(EnableScoring::disabled_from_schema(searcher.schema()))?; let reader = searcher.segment_reader(0); { let mut scorer = weight.scorer(reader, 2.0)?; diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 404c8d77d..219f6b725 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -98,7 +98,7 @@ mod tests { } { let query = query_parser.parse_query("+a b")?; - let weight = query.weight(EnableScoring::Disabled(searcher.schema()))?; + let weight = query.weight(EnableScoring::disabled_from_schema(searcher.schema()))?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?; assert!(scorer.is::()); } diff --git a/src/query/more_like_this/query.rs b/src/query/more_like_this/query.rs index 125a73075..9e6c36424 100644 --- a/src/query/more_like_this/query.rs +++ b/src/query/more_like_this/query.rs @@ -45,7 +45,7 @@ impl Query for MoreLikeThisQuery { fn weight(&self, enable_scoring: EnableScoring<'_>) -> crate::Result> { let searcher = match enable_scoring { EnableScoring::Enabled(searcher) => searcher, - EnableScoring::Disabled(_) => { + EnableScoring::Disabled { .. } => { let err = "MoreLikeThisQuery requires to enable scoring.".to_string(); return Err(crate::TantivyError::InvalidArgument(err)); } diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index 3ebee8475..984659fc9 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -80,7 +80,7 @@ pub mod tests { .collect(); let phrase_query = PhraseQuery::new(terms); let phrase_weight = - phrase_query.phrase_weight(EnableScoring::Disabled(searcher.schema()))?; + phrase_query.phrase_weight(EnableScoring::disabled_from_schema(searcher.schema()))?; let mut phrase_scorer = phrase_weight.scorer(searcher.segment_reader(0), 1.0)?; assert_eq!(phrase_scorer.doc(), 1); assert_eq!(phrase_scorer.advance(), TERMINATED); @@ -361,7 +361,7 @@ pub mod tests { let query_parser = QueryParser::for_index(&index, vec![json_field]); let phrase_query = query_parser.parse_query(query).unwrap(); let phrase_weight = phrase_query - .weight(EnableScoring::Disabled(searcher.schema())) + .weight(EnableScoring::disabled_from_schema(searcher.schema())) .unwrap(); let mut phrase_scorer = phrase_weight .scorer(searcher.segment_reader(0), 1.0f32) diff --git a/src/query/phrase_query/phrase_query.rs b/src/query/phrase_query/phrase_query.rs index 284663997..efc31a044 100644 --- a/src/query/phrase_query/phrase_query.rs +++ b/src/query/phrase_query/phrase_query.rs @@ -109,7 +109,7 @@ impl PhraseQuery { let terms = self.phrase_terms(); let bm25_weight_opt = match enable_scoring { EnableScoring::Enabled(searcher) => Some(Bm25Weight::for_terms(searcher, &terms)?), - EnableScoring::Disabled(_) => None, + EnableScoring::Disabled { .. } => None, }; let mut weight = PhraseWeight::new(self.phrase_terms.clone(), bm25_weight_opt); if self.slop > 0 { diff --git a/src/query/query.rs b/src/query/query.rs index df518d2bc..3899a81d5 100644 --- a/src/query/query.rs +++ b/src/query/query.rs @@ -15,24 +15,55 @@ pub enum EnableScoring<'a> { Enabled(&'a Searcher), /// Pass this to disable scoring. /// This can improve performance. - Disabled(&'a Schema), + Disabled { + /// Schema is required. + schema: &'a Schema, + /// Searcher should be provided if available. + searcher_opt: Option<&'a Searcher>, + }, } impl<'a> EnableScoring<'a> { + /// Create using [Searcher] with scoring enabled. + pub fn enabled_from_searcher(searcher: &'a Searcher) -> EnableScoring<'a> { + EnableScoring::Enabled(searcher) + } + + /// Create using [Searcher] with scoring disabled. + pub fn disabled_from_searcher(searcher: &'a Searcher) -> EnableScoring<'a> { + EnableScoring::Disabled { + schema: searcher.schema(), + searcher_opt: Some(searcher), + } + } + + /// Create using [Schema] with scoring disabled. + pub fn disabled_from_schema(schema: &'a Schema) -> EnableScoring<'a> { + Self::Disabled { + schema, + searcher_opt: None, + } + } + + /// Returns the searcher if available. + pub fn searcher(&self) -> Option<&Searcher> { + match self { + EnableScoring::Enabled(searcher) => Some(searcher), + EnableScoring::Disabled { searcher_opt, .. } => searcher_opt.to_owned(), + } + } + /// Returns the schema. pub fn schema(&self) -> &Schema { match self { EnableScoring::Enabled(searcher) => searcher.schema(), - EnableScoring::Disabled(schema) => schema, + EnableScoring::Disabled { schema, .. } => schema, } } /// Returns true if the scoring is enabled. pub fn is_scoring_enabled(&self) -> bool { - match self { - EnableScoring::Enabled(_) => true, - EnableScoring::Disabled(_) => false, - } + matches!(self, EnableScoring::Enabled(..)) } } @@ -81,14 +112,14 @@ pub trait Query: QueryClone + Send + Sync + downcast_rs::Downcast + fmt::Debug { /// Returns an `Explanation` for the score of the document. fn explain(&self, searcher: &Searcher, doc_address: DocAddress) -> crate::Result { - let weight = self.weight(EnableScoring::Enabled(searcher))?; + let weight = self.weight(EnableScoring::enabled_from_searcher(searcher))?; let reader = searcher.segment_reader(doc_address.segment_ord); weight.explain(reader, doc_address.doc_id) } /// Returns the number of documents matching the query. fn count(&self, searcher: &Searcher) -> crate::Result { - let weight = self.weight(EnableScoring::Disabled(searcher.schema()))?; + let weight = self.weight(EnableScoring::disabled_from_searcher(searcher))?; let mut result = 0; for reader in searcher.segment_readers() { result += weight.count(reader)? as usize; diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs index 53cf9750c..e07d313c6 100644 --- a/src/query/term_query/mod.rs +++ b/src/query/term_query/mod.rs @@ -158,7 +158,8 @@ mod tests { let term_a = Term::from_field_text(text_field, "a"); let term_query = TermQuery::new(term_a, IndexRecordOption::Basic); let searcher = index.reader()?.searcher(); - let term_weight = term_query.weight(EnableScoring::Disabled(searcher.schema()))?; + let term_weight = + term_query.weight(EnableScoring::disabled_from_schema(searcher.schema()))?; let mut term_scorer = term_weight.scorer(searcher.segment_reader(0u32), 1.0)?; assert_eq!(term_scorer.doc(), 0u32); term_scorer.seek(1u32); diff --git a/src/query/term_query/term_query.rs b/src/query/term_query/term_query.rs index 0b1cfb82d..bc13fafc2 100644 --- a/src/query/term_query/term_query.rs +++ b/src/query/term_query/term_query.rs @@ -99,7 +99,7 @@ impl TermQuery { EnableScoring::Enabled(searcher) => { Bm25Weight::for_terms(searcher, &[self.term.clone()])? } - EnableScoring::Disabled(_schema) => { + EnableScoring::Disabled { .. } => { Bm25Weight::new(Explanation::new("".to_string(), 1.0f32), 1.0f32) } };