From 2c3e33895af57cae38911eff0746a24a0772aecc Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Wed, 21 Feb 2018 00:03:41 +0900 Subject: [PATCH] Added unit tests --- Cargo.toml | 2 +- src/query/boolean_query/mod.rs | 58 ++++++++++++++++++++++++- src/query/query_parser/query_grammar.rs | 7 +-- 3 files changed, 61 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3d5569b4a..a9d6cba51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ error-chain = "0.8" owning_ref = "0.3" stable_deref_trait = "1.0.0" rust-stemmers = "0.1.0" -downcast = "0.9" +downcast = { version="0.9", features = ["nightly"]} matches = "0.1" [target.'cfg(windows)'.dependencies] diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index ec132202c..fd7393050 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -10,11 +10,17 @@ mod tests { use query::Occur; use query::Query; use query::TermQuery; + use query::Intersection; + use query::Scorer; + use query::term_query::TermScorer; use collector::tests::TestCollector; use Index; + use downcast::Downcast; use schema::*; - use schema::IndexRecordOption; use query::QueryParser; + use query::RequiredOptionalScorer; + use query::score_combiner::SumWithCoordsCombiner; + fn aux_test_helper() -> (Index, Field) { @@ -56,10 +62,58 @@ mod tests { let (index, text_field) = aux_test_helper(); let query_parser = QueryParser::for_index(&index, vec![text_field]); let query = query_parser.parse_query("(+a +b) d").unwrap(); - println!("{:?}", query); assert_eq!(query.count(&*index.searcher()).unwrap(), 3); } + #[test] + pub fn test_boolean_single_must_clause() { + let (index, text_field) = aux_test_helper(); + let query_parser = QueryParser::for_index(&index, vec![text_field]); + let query = query_parser.parse_query("+a").unwrap(); + let searcher = index.searcher(); + let weight = query.weight(&*searcher, true).unwrap(); + let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap(); + assert!(Downcast::::is_type(&*scorer)); + } + + #[test] + pub fn test_boolean_termonly_intersection() { + let (index, text_field) = aux_test_helper(); + let query_parser = QueryParser::for_index(&index, vec![text_field]); + let searcher = index.searcher(); + { + let query = query_parser.parse_query("+a +b +c").unwrap(); + let weight = query.weight(&*searcher, true).unwrap(); + let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap(); + assert!(Downcast::>::is_type(&*scorer)); + } + { + let query = query_parser.parse_query("+a +(b c)").unwrap(); + let weight = query.weight(&*searcher, true).unwrap(); + let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap(); + assert!(Downcast::>>::is_type(&*scorer)); + } + } + + #[test] + pub fn test_boolean_reqopt() { + let (index, text_field) = aux_test_helper(); + let query_parser = QueryParser::for_index(&index, vec![text_field]); + let searcher = index.searcher(); + { + let query = query_parser.parse_query("+a b").unwrap(); + let weight = query.weight(&*searcher, true).unwrap(); + let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap(); + assert!(Downcast::, Box, SumWithCoordsCombiner>>::is_type(&*scorer)); + } + { + let query = query_parser.parse_query("+a b").unwrap(); + let weight = query.weight(&*searcher, false).unwrap(); + let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap(); + assert!(Downcast::::is_type(&*scorer)); + } + } + #[test] pub fn test_boolean_query() { diff --git a/src/query/query_parser/query_grammar.rs b/src/query/query_parser/query_grammar.rs index ae5d69bda..63691a6d9 100644 --- a/src/query/query_parser/query_grammar.rs +++ b/src/query/query_parser/query_grammar.rs @@ -41,10 +41,10 @@ fn leaf(input: I) -> ParseResult where I: Stream, { - (char('-'), parser(literal)).map(|(_, expr)| UserInputAST::Not(box expr)) - .or((char('+'), parser(literal)).map(|(_, expr)| UserInputAST::Must(box expr))) - .or(parser(literal)) + (char('-'), parser(leaf)).map(|(_, expr)| UserInputAST::Not(box expr)) + .or((char('+'), parser(leaf)).map(|(_, expr)| UserInputAST::Must(box expr))) .or((char('('), parser(parse_to_ast), char(')')).map(|(_, expr, _)| expr)) + .or(parser(literal)) .parse_stream(input) } @@ -80,6 +80,7 @@ mod test { #[test] fn test_parse_query_to_ast() { + test_parse_query_to_ast_helper("+(a b) +d", "(+((\"a\" \"b\")) +(\"d\"))"); test_parse_query_to_ast_helper("(+a +b) d", "((+(\"a\") +(\"b\")) \"d\")"); test_parse_query_to_ast_helper("(+a)", "+(\"a\")"); test_parse_query_to_ast_helper("(+a +b)", "(+(\"a\") +(\"b\"))");