Add in simple regex query support (#319)

* Add fst_regex crate in

* Reduce API surface area

This doesn't need to be public

* better test name

* Pull Automaton weight out so it can be shared

* Implement Regex Query
This commit is contained in:
Dru Sellers
2018-06-16 00:08:30 -05:00
committed by Paul Masurel
parent 24398d94e4
commit 317baf4e75
6 changed files with 157 additions and 60 deletions

View File

@@ -18,6 +18,7 @@ lazy_static = "0.2.1"
tinysegmenter = "0.1.0"
regex = "0.2"
fst = {version="0.3", default-features=false}
fst-regex = { version="0.2" }
lz4 = {version="1.20", optional=true}
snap = {version="0.2"}
atomicwrites = {version="0.2.2", optional=true}

View File

@@ -136,6 +136,7 @@ extern crate combine;
extern crate crossbeam;
extern crate fnv;
extern crate fst;
extern crate fst_regex;
extern crate futures;
extern crate futures_cpupool;
extern crate itertools;

View File

@@ -0,0 +1,57 @@
use common::BitSet;
use core::SegmentReader;
use fst::Automaton;
use query::BitSetDocSet;
use query::ConstScorer;
use query::{Scorer, Weight};
use schema::{Field, IndexRecordOption};
use termdict::{TermDictionary, TermStreamer};
use Result;
pub struct AutomatonWeight<A>
where
A: Automaton,
{
field: Field,
automaton: A,
}
impl<A> AutomatonWeight<A>
where
A: Automaton,
{
pub fn new(field: Field, automaton: A) -> AutomatonWeight<A> {
AutomatonWeight { field, automaton }
}
fn automaton_stream<'a>(&'a self, term_dict: &'a TermDictionary) -> TermStreamer<'a, &'a A> {
let term_stream_builder = term_dict.search(&self.automaton);
term_stream_builder.into_stream()
}
}
impl<A> Weight for AutomatonWeight<A>
where
A: Automaton,
{
fn scorer(&self, reader: &SegmentReader) -> Result<Box<Scorer>> {
let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc);
let inverted_index = reader.inverted_index(self.field);
let term_dict = inverted_index.terms();
let mut term_stream = self.automaton_stream(term_dict);
while term_stream.advance() {
let term_info = term_stream.value();
let mut block_segment_postings =
inverted_index.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic);
while block_segment_postings.advance() {
for &doc in block_segment_postings.docs() {
doc_bitset.insert(doc);
}
}
}
let doc_bitset = BitSetDocSet::from(doc_bitset);
Ok(Box::new(ConstScorer::new(doc_bitset)))
}
}

View File

@@ -1,13 +1,7 @@
use common::BitSet;
use core::SegmentReader;
use fst::Automaton;
use levenshtein_automata::{LevenshteinAutomatonBuilder, DFA};
use query::BitSetDocSet;
use query::ConstScorer;
use query::{Query, Scorer, Weight};
use schema::{Field, IndexRecordOption, Term};
use query::{AutomatonWeight, Query, Weight};
use schema::Term;
use std::collections::HashMap;
use termdict::{TermDictionary, TermStreamer};
use Result;
use Searcher;
@@ -51,6 +45,7 @@ impl FuzzyTermQuery {
}
}
/// Creates a new Fuzzy Query that treats transpositions as cost one rather than two
pub fn new_prefix(term: Term, distance: u8, transposition_cost_one: bool) -> FuzzyTermQuery {
FuzzyTermQuery {
term,
@@ -60,15 +55,11 @@ impl FuzzyTermQuery {
}
}
pub fn specialized_weight(&self) -> Result<AutomatonWeight<DFA>> {
fn specialized_weight(&self) -> Result<AutomatonWeight<DFA>> {
let automaton = LEV_BUILDER.get(&(self.distance, false))
.unwrap() // TODO return an error
.build_dfa(self.term.text());
Ok(AutomatonWeight {
term: self.term.clone(),
field: self.term.field(),
automaton,
})
Ok(AutomatonWeight::new(self.term.field(), automaton))
}
}
@@ -78,51 +69,6 @@ impl Query for FuzzyTermQuery {
}
}
pub struct AutomatonWeight<A>
where
A: Automaton,
{
term: Term,
field: Field,
automaton: A,
}
impl<A> AutomatonWeight<A>
where
A: Automaton,
{
fn automaton_stream<'a>(&'a self, term_dict: &'a TermDictionary) -> TermStreamer<'a, &'a A> {
let term_stream_builder = term_dict.search(&self.automaton);
term_stream_builder.into_stream()
}
}
impl<A> Weight for AutomatonWeight<A>
where
A: Automaton,
{
fn scorer(&self, reader: &SegmentReader) -> Result<Box<Scorer>> {
let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc);
let inverted_index = reader.inverted_index(self.field);
let term_dict = inverted_index.terms();
let mut term_stream = self.automaton_stream(term_dict);
while term_stream.advance() {
let term_info = term_stream.value();
let mut block_segment_postings = inverted_index
.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic);
while block_segment_postings.advance() {
for &doc in block_segment_postings.docs() {
doc_bitset.insert(doc);
}
}
}
let doc_bitset = BitSetDocSet::from(doc_bitset);
Ok(Box::new(ConstScorer::new(doc_bitset)))
}
}
#[cfg(test)]
mod test {
use super::FuzzyTermQuery;
@@ -134,7 +80,7 @@ mod test {
use Term;
#[test]
pub fn test_automaton_weight() {
pub fn test_fuzzy_term() {
let mut schema_builder = SchemaBuilder::new();
let country_field = schema_builder.add_text_field("country", TEXT);
let schema = schema_builder.build();

View File

@@ -3,6 +3,7 @@ Query
*/
mod all_query;
mod automaton_weight;
mod bitset;
mod bm25;
mod boolean_query;
@@ -14,6 +15,7 @@ mod phrase_query;
mod query;
mod query_parser;
mod range_query;
mod regex_query;
mod reqopt_scorer;
mod scorer;
mod term_query;
@@ -32,6 +34,7 @@ pub use self::union::Union;
pub use self::vec_docset::VecDocSet;
pub use self::all_query::{AllQuery, AllScorer, AllWeight};
pub use self::automaton_weight::AutomatonWeight;
pub use self::bitset::BitSetDocSet;
pub use self::boolean_query::BooleanQuery;
pub use self::exclude::Exclude;

89
src/query/regex_query.rs Normal file
View File

@@ -0,0 +1,89 @@
use error::ErrorKind;
use fst_regex::Regex;
use query::{AutomatonWeight, Query, Weight};
use schema::Field;
use std::clone::Clone;
use Result;
use Searcher;
// A Regex Query matches all of the documents
/// containing a specific term that matches
/// a regex pattern
#[derive(Debug, Clone)]
pub struct RegexQuery {
regex_pattern: String,
field: Field,
}
impl RegexQuery {
/// Creates a new Fuzzy Query
pub fn new(regex_pattern: String, field: Field) -> RegexQuery {
RegexQuery {
regex_pattern,
field,
}
}
fn specialized_weight(&self) -> Result<AutomatonWeight<Regex>> {
let automaton = Regex::new(&self.regex_pattern)
.map_err(|_| ErrorKind::InvalidArgument(self.regex_pattern.clone()))?;
Ok(AutomatonWeight::new(self.field.clone(), automaton))
}
}
impl Query for RegexQuery {
fn weight(&self, _searcher: &Searcher, _scoring_enabled: bool) -> Result<Box<Weight>> {
Ok(Box::new(self.specialized_weight()?))
}
}
#[cfg(test)]
mod test {
use super::RegexQuery;
use collector::TopCollector;
use schema::SchemaBuilder;
use schema::TEXT;
use tests::assert_nearly_equals;
use Index;
#[test]
pub fn test_regex_query() {
let mut schema_builder = SchemaBuilder::new();
let country_field = schema_builder.add_text_field("country", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 10_000_000).unwrap();
index_writer.add_document(doc!(
country_field => "japan",
));
index_writer.add_document(doc!(
country_field => "korea",
));
index_writer.commit().unwrap();
}
index.load_searchers().unwrap();
let searcher = index.searcher();
{
let mut collector = TopCollector::with_limit(2);
let regex_query = RegexQuery::new("jap[ao]n".to_string(), country_field);
searcher.search(&regex_query, &mut collector).unwrap();
let scored_docs = collector.score_docs();
assert_eq!(scored_docs.len(), 1, "Expected only 1 document");
let (score, _) = scored_docs[0];
assert_nearly_equals(1f32, score);
}
let searcher = index.searcher();
{
let mut collector = TopCollector::with_limit(2);
let regex_query = RegexQuery::new("jap[A-Z]n".to_string(), country_field);
searcher.search(&regex_query, &mut collector).unwrap();
let scored_docs = collector.score_docs();
assert_eq!(scored_docs.len(), 0, "Expected ZERO document");
}
}
}