diff --git a/src/lib.rs b/src/lib.rs index 7367050ce..f226778bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -298,17 +298,26 @@ mod tests { use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; - pub fn assert_nearly_equals(expected: f32, val: f32) { - assert!( - nearly_equals(val, expected), - "Got {}, expected {}.", - val, - expected - ); - } - - pub fn nearly_equals(a: f32, b: f32) -> bool { - (a - b).abs() < 0.0005 * (a + b).abs() + /// Checks if left and right are close one to each other. + /// Panics if the two values are more than 0.5% apart. + #[macro_export] + macro_rules! assert_nearly_equals { + ($left:expr, $right:expr) => {{ + match (&$left, &$right) { + (left_val, right_val) => { + let diff = (left_val - right_val).abs(); + let add = left_val.abs() + right_val.abs(); + if diff > 0.0005 * add { + panic!( + r#"assertion failed: `(left ~= right)` + left: `{:?}`, + right: `{:?}`"#, + &*left_val, &*right_val + ) + } + } + } + }}; } pub fn generate_nonunique_unsorted(max_value: u32, n_elems: usize) -> Vec { diff --git a/src/query/bm25.rs b/src/query/bm25.rs index 48f84fece..4c9580d71 100644 --- a/src/query/bm25.rs +++ b/src/query/bm25.rs @@ -139,10 +139,10 @@ impl BM25Weight { mod tests { use super::idf; - use crate::tests::assert_nearly_equals; + use crate::assert_nearly_equals; #[test] fn test_idf() { - assert_nearly_equals(idf(1, 2), 0.6931472); + assert_nearly_equals!(idf(1, 2), 0.6931472); } } diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 61aa46a5c..a90a60e92 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -7,6 +7,7 @@ pub use self::boolean_query::BooleanQuery; mod tests { use super::*; + use crate::assert_nearly_equals; use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE; use crate::collector::TopDocs; use crate::query::score_combiner::SumWithCoordsCombiner; @@ -19,7 +20,6 @@ mod tests { use crate::query::Scorer; use crate::query::TermQuery; use crate::schema::*; - use crate::tests::assert_nearly_equals; use crate::Index; use crate::{DocAddress, DocId, Score}; @@ -256,14 +256,14 @@ mod tests { .scorer(searcher.segment_reader(0u32), 1.0f32) .unwrap(); assert_eq!(boolean_scorer.doc(), 0u32); - assert_nearly_equals(boolean_scorer.score(), 0.84163445f32); + assert_nearly_equals!(boolean_scorer.score(), 0.84163445f32); } { let mut boolean_scorer = boolean_weight .scorer(searcher.segment_reader(0u32), 2.0f32) .unwrap(); assert_eq!(boolean_scorer.doc(), 0u32); - assert_nearly_equals(boolean_scorer.score(), 1.6832689f32); + assert_nearly_equals!(boolean_scorer.score(), 1.6832689f32); } } diff --git a/src/query/fuzzy_query.rs b/src/query/fuzzy_query.rs index 35c8d5654..a7429d548 100644 --- a/src/query/fuzzy_query.rs +++ b/src/query/fuzzy_query.rs @@ -163,10 +163,10 @@ impl Query for FuzzyTermQuery { #[cfg(test)] mod test { use super::FuzzyTermQuery; + use crate::assert_nearly_equals; use crate::collector::TopDocs; use crate::schema::Schema; use crate::schema::TEXT; - use crate::tests::assert_nearly_equals; use crate::Index; use crate::Term; @@ -199,7 +199,7 @@ mod test { .unwrap(); assert_eq!(top_docs.len(), 1, "Expected only 1 document"); let (score, _) = top_docs[0]; - assert_nearly_equals(1f32, score); + assert_nearly_equals!(1f32, score); } // fails because non-prefix Levenshtein distance is more than 1 (add 'a' and 'n') @@ -223,7 +223,7 @@ mod test { .unwrap(); assert_eq!(top_docs.len(), 1, "Expected only 1 document"); let (score, _) = top_docs[0]; - assert_nearly_equals(1f32, score); + assert_nearly_equals!(1f32, score); } } } diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index fe3b2895a..f8967f8a8 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -10,11 +10,11 @@ pub use self::phrase_weight::PhraseWeight; pub mod tests { use super::*; + use crate::assert_nearly_equals; use crate::collector::tests::{TEST_COLLECTOR_WITHOUT_SCORE, TEST_COLLECTOR_WITH_SCORE}; use crate::core::Index; use crate::query::Weight; use crate::schema::{Schema, Term, TEXT}; - use crate::tests::assert_nearly_equals; use crate::DocId; use crate::{DocAddress, TERMINATED}; @@ -175,8 +175,8 @@ pub mod tests { .to_vec() }; let scores = test_query(vec!["a", "b"]); - assert_nearly_equals(scores[0], 0.40618482); - assert_nearly_equals(scores[1], 0.46844664); + assert_nearly_equals!(scores[0], 0.40618482); + assert_nearly_equals!(scores[1], 0.46844664); } #[test] // motivated by #234 diff --git a/src/query/regex_query.rs b/src/query/regex_query.rs index 42c521ce3..6672a8c42 100644 --- a/src/query/regex_query.rs +++ b/src/query/regex_query.rs @@ -89,10 +89,10 @@ impl Query for RegexQuery { #[cfg(test)] mod test { use super::RegexQuery; + use crate::assert_nearly_equals; use crate::collector::TopDocs; use crate::schema::TEXT; use crate::schema::{Field, Schema}; - use crate::tests::assert_nearly_equals; use crate::{Index, IndexReader}; use std::sync::Arc; use tantivy_fst::Regex; @@ -129,7 +129,7 @@ mod test { .unwrap(); assert_eq!(scored_docs.len(), 1, "Expected only 1 document"); let (score, _) = scored_docs[0]; - assert_nearly_equals(1f32, score); + assert_nearly_equals!(1f32, score); } let top_docs = searcher .search(&query_matching_zero, &TopDocs::with_limit(2)) diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs index 69b011215..d69c557b6 100644 --- a/src/query/term_query/mod.rs +++ b/src/query/term_query/mod.rs @@ -9,12 +9,12 @@ pub use self::term_weight::TermWeight; #[cfg(test)] mod tests { + use crate::assert_nearly_equals; use crate::collector::TopDocs; use crate::docset::DocSet; use crate::postings::compression::COMPRESSION_BLOCK_SIZE; use crate::query::{Query, QueryParser, Scorer, TermQuery}; use crate::schema::{Field, IndexRecordOption, Schema, STRING, TEXT}; - use crate::tests::assert_nearly_equals; use crate::Term; use crate::{Index, TERMINATED}; @@ -105,7 +105,7 @@ mod tests { .unwrap(); assert_eq!(topdocs.len(), 1); let (score, _) = topdocs[0]; - assert_nearly_equals(0.77802235, score); + assert_nearly_equals!(0.77802235, score); } { let term = Term::from_field_text(left_field, "left1"); @@ -115,9 +115,9 @@ mod tests { .unwrap(); assert_eq!(top_docs.len(), 2); let (score1, _) = top_docs[0]; - assert_nearly_equals(0.27101856, score1); + assert_nearly_equals!(0.27101856, score1); let (score2, _) = top_docs[1]; - assert_nearly_equals(0.13736556, score2); + assert_nearly_equals!(0.13736556, score2); } { let query_parser = QueryParser::for_index(&index, vec![]); @@ -125,9 +125,9 @@ mod tests { let top_docs = searcher.search(&query, &TopDocs::with_limit(2)).unwrap(); assert_eq!(top_docs.len(), 2); let (score1, _) = top_docs[0]; - assert_nearly_equals(0.9153879, score1); + assert_nearly_equals!(0.9153879, score1); let (score2, _) = top_docs[1]; - assert_nearly_equals(0.27101856, score2); + assert_nearly_equals!(0.27101856, score2); } } diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs index 57ca3f87e..7e911b288 100644 --- a/src/query/term_query/term_weight.rs +++ b/src/query/term_query/term_weight.rs @@ -20,12 +20,12 @@ pub struct TermWeight { impl Weight for TermWeight { fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result> { - let term_scorer = self.scorer_specialized(reader, boost)?; + let term_scorer = self.specialized_scorer(reader, boost)?; Ok(Box::new(term_scorer)) } fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result { - let mut scorer = self.scorer_specialized(reader, 1.0f32)?; + let mut scorer = self.specialized_scorer(reader, 1.0f32)?; if scorer.seek(doc) != doc { return Err(does_not_match(doc)); } @@ -52,7 +52,7 @@ impl Weight for TermWeight { reader: &SegmentReader, callback: &mut dyn FnMut(DocId, Score), ) -> crate::Result<()> { - let mut scorer = self.scorer_specialized(reader, 1.0f32)?; + let mut scorer = self.specialized_scorer(reader, 1.0f32)?; for_each_scorer(&mut scorer, callback); Ok(()) } @@ -92,7 +92,7 @@ impl TermWeight { } } - fn scorer_specialized(&self, reader: &SegmentReader, boost: f32) -> Result { + fn specialized_scorer(&self, reader: &SegmentReader, boost: f32) -> Result { let field = self.term.field(); let inverted_index = reader.inverted_index(field); let fieldnorm_reader = reader.get_fieldnorms_reader(field);