diff --git a/Cargo.toml b/Cargo.toml
index 8d1c5668c..eb8cbf29d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -18,6 +18,7 @@ lazy_static = "0.2.1"
tinysegmenter = "0.1.0"
regex = "0.2"
fst = {version="0.3", default-features=false}
+fst-regex = { version="0.2" }
lz4 = {version="1.20", optional=true}
snap = {version="0.2"}
atomicwrites = {version="0.2.2", optional=true}
diff --git a/src/lib.rs b/src/lib.rs
index 599a5df6c..93d8626f8 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -136,6 +136,7 @@ extern crate combine;
extern crate crossbeam;
extern crate fnv;
extern crate fst;
+extern crate fst_regex;
extern crate futures;
extern crate futures_cpupool;
extern crate itertools;
diff --git a/src/query/automaton_weight.rs b/src/query/automaton_weight.rs
new file mode 100644
index 000000000..c5da69180
--- /dev/null
+++ b/src/query/automaton_weight.rs
@@ -0,0 +1,57 @@
+use common::BitSet;
+use core::SegmentReader;
+use fst::Automaton;
+use query::BitSetDocSet;
+use query::ConstScorer;
+use query::{Scorer, Weight};
+use schema::{Field, IndexRecordOption};
+use termdict::{TermDictionary, TermStreamer};
+use Result;
+
+pub struct AutomatonWeight
+where
+ A: Automaton,
+{
+ field: Field,
+ automaton: A,
+}
+
+impl AutomatonWeight
+where
+ A: Automaton,
+{
+ pub fn new(field: Field, automaton: A) -> AutomatonWeight {
+ AutomatonWeight { field, automaton }
+ }
+
+ fn automaton_stream<'a>(&'a self, term_dict: &'a TermDictionary) -> TermStreamer<'a, &'a A> {
+ let term_stream_builder = term_dict.search(&self.automaton);
+ term_stream_builder.into_stream()
+ }
+}
+
+impl Weight for AutomatonWeight
+where
+ A: Automaton,
+{
+ fn scorer(&self, reader: &SegmentReader) -> Result> {
+ let max_doc = reader.max_doc();
+ let mut doc_bitset = BitSet::with_max_value(max_doc);
+
+ let inverted_index = reader.inverted_index(self.field);
+ let term_dict = inverted_index.terms();
+ let mut term_stream = self.automaton_stream(term_dict);
+ while term_stream.advance() {
+ let term_info = term_stream.value();
+ let mut block_segment_postings =
+ inverted_index.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic);
+ while block_segment_postings.advance() {
+ for &doc in block_segment_postings.docs() {
+ doc_bitset.insert(doc);
+ }
+ }
+ }
+ let doc_bitset = BitSetDocSet::from(doc_bitset);
+ Ok(Box::new(ConstScorer::new(doc_bitset)))
+ }
+}
diff --git a/src/query/fuzzy_query.rs b/src/query/fuzzy_query.rs
index 9354c8f68..aa6de9139 100644
--- a/src/query/fuzzy_query.rs
+++ b/src/query/fuzzy_query.rs
@@ -1,13 +1,7 @@
-use common::BitSet;
-use core::SegmentReader;
-use fst::Automaton;
use levenshtein_automata::{LevenshteinAutomatonBuilder, DFA};
-use query::BitSetDocSet;
-use query::ConstScorer;
-use query::{Query, Scorer, Weight};
-use schema::{Field, IndexRecordOption, Term};
+use query::{AutomatonWeight, Query, Weight};
+use schema::Term;
use std::collections::HashMap;
-use termdict::{TermDictionary, TermStreamer};
use Result;
use Searcher;
@@ -51,6 +45,7 @@ impl FuzzyTermQuery {
}
}
+ /// Creates a new Fuzzy Query that treats transpositions as cost one rather than two
pub fn new_prefix(term: Term, distance: u8, transposition_cost_one: bool) -> FuzzyTermQuery {
FuzzyTermQuery {
term,
@@ -60,15 +55,11 @@ impl FuzzyTermQuery {
}
}
- pub fn specialized_weight(&self) -> Result> {
+ fn specialized_weight(&self) -> Result> {
let automaton = LEV_BUILDER.get(&(self.distance, false))
.unwrap() // TODO return an error
.build_dfa(self.term.text());
- Ok(AutomatonWeight {
- term: self.term.clone(),
- field: self.term.field(),
- automaton,
- })
+ Ok(AutomatonWeight::new(self.term.field(), automaton))
}
}
@@ -78,51 +69,6 @@ impl Query for FuzzyTermQuery {
}
}
-pub struct AutomatonWeight
-where
- A: Automaton,
-{
- term: Term,
- field: Field,
- automaton: A,
-}
-
-impl AutomatonWeight
-where
- A: Automaton,
-{
- fn automaton_stream<'a>(&'a self, term_dict: &'a TermDictionary) -> TermStreamer<'a, &'a A> {
- let term_stream_builder = term_dict.search(&self.automaton);
- term_stream_builder.into_stream()
- }
-}
-
-impl Weight for AutomatonWeight
-where
- A: Automaton,
-{
- fn scorer(&self, reader: &SegmentReader) -> Result> {
- let max_doc = reader.max_doc();
- let mut doc_bitset = BitSet::with_max_value(max_doc);
-
- let inverted_index = reader.inverted_index(self.field);
- let term_dict = inverted_index.terms();
- let mut term_stream = self.automaton_stream(term_dict);
- while term_stream.advance() {
- let term_info = term_stream.value();
- let mut block_segment_postings = inverted_index
- .read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic);
- while block_segment_postings.advance() {
- for &doc in block_segment_postings.docs() {
- doc_bitset.insert(doc);
- }
- }
- }
- let doc_bitset = BitSetDocSet::from(doc_bitset);
- Ok(Box::new(ConstScorer::new(doc_bitset)))
- }
-}
-
#[cfg(test)]
mod test {
use super::FuzzyTermQuery;
@@ -134,7 +80,7 @@ mod test {
use Term;
#[test]
- pub fn test_automaton_weight() {
+ pub fn test_fuzzy_term() {
let mut schema_builder = SchemaBuilder::new();
let country_field = schema_builder.add_text_field("country", TEXT);
let schema = schema_builder.build();
diff --git a/src/query/mod.rs b/src/query/mod.rs
index 49fb8d296..9838807ac 100644
--- a/src/query/mod.rs
+++ b/src/query/mod.rs
@@ -3,6 +3,7 @@ Query
*/
mod all_query;
+mod automaton_weight;
mod bitset;
mod bm25;
mod boolean_query;
@@ -14,6 +15,7 @@ mod phrase_query;
mod query;
mod query_parser;
mod range_query;
+mod regex_query;
mod reqopt_scorer;
mod scorer;
mod term_query;
@@ -32,6 +34,7 @@ pub use self::union::Union;
pub use self::vec_docset::VecDocSet;
pub use self::all_query::{AllQuery, AllScorer, AllWeight};
+pub use self::automaton_weight::AutomatonWeight;
pub use self::bitset::BitSetDocSet;
pub use self::boolean_query::BooleanQuery;
pub use self::exclude::Exclude;
diff --git a/src/query/regex_query.rs b/src/query/regex_query.rs
new file mode 100644
index 000000000..6f7209eb3
--- /dev/null
+++ b/src/query/regex_query.rs
@@ -0,0 +1,89 @@
+use error::ErrorKind;
+use fst_regex::Regex;
+use query::{AutomatonWeight, Query, Weight};
+use schema::Field;
+use std::clone::Clone;
+use Result;
+use Searcher;
+
+// A Regex Query matches all of the documents
+/// containing a specific term that matches
+/// a regex pattern
+#[derive(Debug, Clone)]
+pub struct RegexQuery {
+ regex_pattern: String,
+ field: Field,
+}
+
+impl RegexQuery {
+ /// Creates a new Fuzzy Query
+ pub fn new(regex_pattern: String, field: Field) -> RegexQuery {
+ RegexQuery {
+ regex_pattern,
+ field,
+ }
+ }
+
+ fn specialized_weight(&self) -> Result> {
+ let automaton = Regex::new(&self.regex_pattern)
+ .map_err(|_| ErrorKind::InvalidArgument(self.regex_pattern.clone()))?;
+
+ Ok(AutomatonWeight::new(self.field.clone(), automaton))
+ }
+}
+
+impl Query for RegexQuery {
+ fn weight(&self, _searcher: &Searcher, _scoring_enabled: bool) -> Result> {
+ Ok(Box::new(self.specialized_weight()?))
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::RegexQuery;
+ use collector::TopCollector;
+ use schema::SchemaBuilder;
+ use schema::TEXT;
+ use tests::assert_nearly_equals;
+ use Index;
+
+ #[test]
+ pub fn test_regex_query() {
+ let mut schema_builder = SchemaBuilder::new();
+ let country_field = schema_builder.add_text_field("country", TEXT);
+ let schema = schema_builder.build();
+ let index = Index::create_in_ram(schema);
+ {
+ let mut index_writer = index.writer_with_num_threads(1, 10_000_000).unwrap();
+ index_writer.add_document(doc!(
+ country_field => "japan",
+ ));
+ index_writer.add_document(doc!(
+ country_field => "korea",
+ ));
+ index_writer.commit().unwrap();
+ }
+ index.load_searchers().unwrap();
+ let searcher = index.searcher();
+ {
+ let mut collector = TopCollector::with_limit(2);
+
+ let regex_query = RegexQuery::new("jap[ao]n".to_string(), country_field);
+ searcher.search(®ex_query, &mut collector).unwrap();
+ let scored_docs = collector.score_docs();
+ assert_eq!(scored_docs.len(), 1, "Expected only 1 document");
+ let (score, _) = scored_docs[0];
+ assert_nearly_equals(1f32, score);
+ }
+
+ let searcher = index.searcher();
+ {
+ let mut collector = TopCollector::with_limit(2);
+
+ let regex_query = RegexQuery::new("jap[A-Z]n".to_string(), country_field);
+ searcher.search(®ex_query, &mut collector).unwrap();
+ let scored_docs = collector.score_docs();
+ assert_eq!(scored_docs.len(), 0, "Expected ZERO document");
+ }
+ }
+}