From 317baf4e75fdd0ce4f98d3f76507c988a5bb0c5d Mon Sep 17 00:00:00 2001 From: Dru Sellers Date: Sat, 16 Jun 2018 00:08:30 -0500 Subject: [PATCH] Add in simple regex query support (#319) * Add fst_regex crate in * Reduce API surface area This doesn't need to be public * better test name * Pull Automaton weight out so it can be shared * Implement Regex Query --- Cargo.toml | 1 + src/lib.rs | 1 + src/query/automaton_weight.rs | 57 ++++++++++++++++++++++ src/query/fuzzy_query.rs | 66 +++----------------------- src/query/mod.rs | 3 ++ src/query/regex_query.rs | 89 +++++++++++++++++++++++++++++++++++ 6 files changed, 157 insertions(+), 60 deletions(-) create mode 100644 src/query/automaton_weight.rs create mode 100644 src/query/regex_query.rs diff --git a/Cargo.toml b/Cargo.toml index 8d1c5668c..eb8cbf29d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ lazy_static = "0.2.1" tinysegmenter = "0.1.0" regex = "0.2" fst = {version="0.3", default-features=false} +fst-regex = { version="0.2" } lz4 = {version="1.20", optional=true} snap = {version="0.2"} atomicwrites = {version="0.2.2", optional=true} diff --git a/src/lib.rs b/src/lib.rs index 599a5df6c..93d8626f8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -136,6 +136,7 @@ extern crate combine; extern crate crossbeam; extern crate fnv; extern crate fst; +extern crate fst_regex; extern crate futures; extern crate futures_cpupool; extern crate itertools; diff --git a/src/query/automaton_weight.rs b/src/query/automaton_weight.rs new file mode 100644 index 000000000..c5da69180 --- /dev/null +++ b/src/query/automaton_weight.rs @@ -0,0 +1,57 @@ +use common::BitSet; +use core::SegmentReader; +use fst::Automaton; +use query::BitSetDocSet; +use query::ConstScorer; +use query::{Scorer, Weight}; +use schema::{Field, IndexRecordOption}; +use termdict::{TermDictionary, TermStreamer}; +use Result; + +pub struct AutomatonWeight +where + A: Automaton, +{ + field: Field, + automaton: A, +} + +impl AutomatonWeight +where + A: Automaton, +{ + pub fn new(field: Field, automaton: A) -> AutomatonWeight { + AutomatonWeight { field, automaton } + } + + fn automaton_stream<'a>(&'a self, term_dict: &'a TermDictionary) -> TermStreamer<'a, &'a A> { + let term_stream_builder = term_dict.search(&self.automaton); + term_stream_builder.into_stream() + } +} + +impl Weight for AutomatonWeight +where + A: Automaton, +{ + fn scorer(&self, reader: &SegmentReader) -> Result> { + let max_doc = reader.max_doc(); + let mut doc_bitset = BitSet::with_max_value(max_doc); + + let inverted_index = reader.inverted_index(self.field); + let term_dict = inverted_index.terms(); + let mut term_stream = self.automaton_stream(term_dict); + while term_stream.advance() { + let term_info = term_stream.value(); + let mut block_segment_postings = + inverted_index.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic); + while block_segment_postings.advance() { + for &doc in block_segment_postings.docs() { + doc_bitset.insert(doc); + } + } + } + let doc_bitset = BitSetDocSet::from(doc_bitset); + Ok(Box::new(ConstScorer::new(doc_bitset))) + } +} diff --git a/src/query/fuzzy_query.rs b/src/query/fuzzy_query.rs index 9354c8f68..aa6de9139 100644 --- a/src/query/fuzzy_query.rs +++ b/src/query/fuzzy_query.rs @@ -1,13 +1,7 @@ -use common::BitSet; -use core::SegmentReader; -use fst::Automaton; use levenshtein_automata::{LevenshteinAutomatonBuilder, DFA}; -use query::BitSetDocSet; -use query::ConstScorer; -use query::{Query, Scorer, Weight}; -use schema::{Field, IndexRecordOption, Term}; +use query::{AutomatonWeight, Query, Weight}; +use schema::Term; use std::collections::HashMap; -use termdict::{TermDictionary, TermStreamer}; use Result; use Searcher; @@ -51,6 +45,7 @@ impl FuzzyTermQuery { } } + /// Creates a new Fuzzy Query that treats transpositions as cost one rather than two pub fn new_prefix(term: Term, distance: u8, transposition_cost_one: bool) -> FuzzyTermQuery { FuzzyTermQuery { term, @@ -60,15 +55,11 @@ impl FuzzyTermQuery { } } - pub fn specialized_weight(&self) -> Result> { + fn specialized_weight(&self) -> Result> { let automaton = LEV_BUILDER.get(&(self.distance, false)) .unwrap() // TODO return an error .build_dfa(self.term.text()); - Ok(AutomatonWeight { - term: self.term.clone(), - field: self.term.field(), - automaton, - }) + Ok(AutomatonWeight::new(self.term.field(), automaton)) } } @@ -78,51 +69,6 @@ impl Query for FuzzyTermQuery { } } -pub struct AutomatonWeight -where - A: Automaton, -{ - term: Term, - field: Field, - automaton: A, -} - -impl AutomatonWeight -where - A: Automaton, -{ - fn automaton_stream<'a>(&'a self, term_dict: &'a TermDictionary) -> TermStreamer<'a, &'a A> { - let term_stream_builder = term_dict.search(&self.automaton); - term_stream_builder.into_stream() - } -} - -impl Weight for AutomatonWeight -where - A: Automaton, -{ - fn scorer(&self, reader: &SegmentReader) -> Result> { - let max_doc = reader.max_doc(); - let mut doc_bitset = BitSet::with_max_value(max_doc); - - let inverted_index = reader.inverted_index(self.field); - let term_dict = inverted_index.terms(); - let mut term_stream = self.automaton_stream(term_dict); - while term_stream.advance() { - let term_info = term_stream.value(); - let mut block_segment_postings = inverted_index - .read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic); - while block_segment_postings.advance() { - for &doc in block_segment_postings.docs() { - doc_bitset.insert(doc); - } - } - } - let doc_bitset = BitSetDocSet::from(doc_bitset); - Ok(Box::new(ConstScorer::new(doc_bitset))) - } -} - #[cfg(test)] mod test { use super::FuzzyTermQuery; @@ -134,7 +80,7 @@ mod test { use Term; #[test] - pub fn test_automaton_weight() { + pub fn test_fuzzy_term() { let mut schema_builder = SchemaBuilder::new(); let country_field = schema_builder.add_text_field("country", TEXT); let schema = schema_builder.build(); diff --git a/src/query/mod.rs b/src/query/mod.rs index 49fb8d296..9838807ac 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -3,6 +3,7 @@ Query */ mod all_query; +mod automaton_weight; mod bitset; mod bm25; mod boolean_query; @@ -14,6 +15,7 @@ mod phrase_query; mod query; mod query_parser; mod range_query; +mod regex_query; mod reqopt_scorer; mod scorer; mod term_query; @@ -32,6 +34,7 @@ pub use self::union::Union; pub use self::vec_docset::VecDocSet; pub use self::all_query::{AllQuery, AllScorer, AllWeight}; +pub use self::automaton_weight::AutomatonWeight; pub use self::bitset::BitSetDocSet; pub use self::boolean_query::BooleanQuery; pub use self::exclude::Exclude; diff --git a/src/query/regex_query.rs b/src/query/regex_query.rs new file mode 100644 index 000000000..6f7209eb3 --- /dev/null +++ b/src/query/regex_query.rs @@ -0,0 +1,89 @@ +use error::ErrorKind; +use fst_regex::Regex; +use query::{AutomatonWeight, Query, Weight}; +use schema::Field; +use std::clone::Clone; +use Result; +use Searcher; + +// A Regex Query matches all of the documents +/// containing a specific term that matches +/// a regex pattern +#[derive(Debug, Clone)] +pub struct RegexQuery { + regex_pattern: String, + field: Field, +} + +impl RegexQuery { + /// Creates a new Fuzzy Query + pub fn new(regex_pattern: String, field: Field) -> RegexQuery { + RegexQuery { + regex_pattern, + field, + } + } + + fn specialized_weight(&self) -> Result> { + let automaton = Regex::new(&self.regex_pattern) + .map_err(|_| ErrorKind::InvalidArgument(self.regex_pattern.clone()))?; + + Ok(AutomatonWeight::new(self.field.clone(), automaton)) + } +} + +impl Query for RegexQuery { + fn weight(&self, _searcher: &Searcher, _scoring_enabled: bool) -> Result> { + Ok(Box::new(self.specialized_weight()?)) + } +} + +#[cfg(test)] +mod test { + use super::RegexQuery; + use collector::TopCollector; + use schema::SchemaBuilder; + use schema::TEXT; + use tests::assert_nearly_equals; + use Index; + + #[test] + pub fn test_regex_query() { + let mut schema_builder = SchemaBuilder::new(); + let country_field = schema_builder.add_text_field("country", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + { + let mut index_writer = index.writer_with_num_threads(1, 10_000_000).unwrap(); + index_writer.add_document(doc!( + country_field => "japan", + )); + index_writer.add_document(doc!( + country_field => "korea", + )); + index_writer.commit().unwrap(); + } + index.load_searchers().unwrap(); + let searcher = index.searcher(); + { + let mut collector = TopCollector::with_limit(2); + + let regex_query = RegexQuery::new("jap[ao]n".to_string(), country_field); + searcher.search(®ex_query, &mut collector).unwrap(); + let scored_docs = collector.score_docs(); + assert_eq!(scored_docs.len(), 1, "Expected only 1 document"); + let (score, _) = scored_docs[0]; + assert_nearly_equals(1f32, score); + } + + let searcher = index.searcher(); + { + let mut collector = TopCollector::with_limit(2); + + let regex_query = RegexQuery::new("jap[A-Z]n".to_string(), country_field); + searcher.search(®ex_query, &mut collector).unwrap(); + let scored_docs = collector.score_docs(); + assert_eq!(scored_docs.len(), 0, "Expected ZERO document"); + } + } +}