diff --git a/Cargo.toml b/Cargo.toml
index 4d2999409..f760f3f9d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -42,6 +42,7 @@ futures-cpupool = "0.1"
 error-chain = "0.8"
 owning_ref = "0.3"
 stable_deref_trait = "1.0.0"
+rust-stemmers = "0.1.0"
 
 [target.'cfg(windows)'.dependencies]
 winapi = "0.2"
diff --git a/examples/simple_search.rs b/examples/simple_search.rs
index 0d35f0e42..afdeb47c6 100644
--- a/examples/simple_search.rs
+++ b/examples/simple_search.rs
@@ -179,7 +179,7 @@ fn run_example(index_path: &Path) -> tantivy::Result<()> {
     // Here, if the user does not specify which
     // field they want to search, tantivy will search
     // in both title and body.
-    let query_parser = QueryParser::new(index.schema(), vec![title, body]);
+    let mut query_parser = QueryParser::new(index.schema(), vec![title, body]);
 
     // QueryParser may fail if the query is not in the right
     // format. For user facing applications, this can be a problem.
diff --git a/src/analyzer/analyzer.rs b/src/analyzer/analyzer.rs
new file mode 100644
index 000000000..c1a916a1d
--- /dev/null
+++ b/src/analyzer/analyzer.rs
@@ -0,0 +1,72 @@
+
+
+
+
+#[derive(Default)]
+pub struct Token {
+    pub offset_from: usize,
+    pub offset_to: usize,
+    pub position: usize,
+    pub term: String,
+}
+
+pub trait Analyzer<'a>: Sized {
+
+    type TokenStreamImpl: TokenStream;
+
+    fn analyze(&mut self, text: &'a str) -> Self::TokenStreamImpl;
+
+    fn filter<NewFilter>(self, new_filter: NewFilter) -> ChainAnalyzer<NewFilter, Self>
+        where NewFilter: TokenFilterFactory<<Self as Analyzer<'a>>::TokenStreamImpl> {
+        ChainAnalyzer {
+            head: new_filter,
+            tail: self
+        }
+    }
+}
+
+pub trait TokenStream {
+
+    fn advance(&mut self) -> bool;
+
+    fn token(&self) -> &Token;
+
+    fn token_mut(&mut self) -> &mut Token;
+
+    fn next(&mut self) -> Option<&Token> {
+        if self.advance() {
+            Some(self.token())
+        }
+        else {
+            None
+        }
+    }
+}
+
+
+pub struct ChainAnalyzer<HeadTokenFilterFactory, TailAnalyzer> {
+    head: HeadTokenFilterFactory,
+    tail: TailAnalyzer
+}
+
+
+impl<'a, HeadTokenFilterFactory, TailAnalyzer> Analyzer<'a> for ChainAnalyzer<HeadTokenFilterFactory, TailAnalyzer>
+    where HeadTokenFilterFactory: TokenFilterFactory<TailAnalyzer::TokenStreamImpl>,
+          TailAnalyzer: Analyzer<'a> {
+
+    type TokenStreamImpl = HeadTokenFilterFactory::ResultTokenStream;
+
+    fn analyze(&mut self, text: &'a str) -> Self::TokenStreamImpl {
+        let tail_token_stream = self.tail.analyze(text);
+        self.head.transform(tail_token_stream)
+    }
+}
+
+
+pub trait TokenFilterFactory<TailTokenStream: TokenStream> {
+
+    type ResultTokenStream: TokenStream;
+
+    fn transform(&self, token_stream: TailTokenStream) -> Self::ResultTokenStream;
+}
+
diff --git a/src/analyzer/lower_caser.rs b/src/analyzer/lower_caser.rs
new file mode 100644
index 000000000..dda5f597b
--- /dev/null
+++ b/src/analyzer/lower_caser.rs
@@ -0,0 +1,54 @@
+use super::{TokenFilterFactory, TokenStream, Token};
+use std::ascii::AsciiExt;
+
+pub struct LowerCaser;
+
+impl<TailTokenStream> TokenFilterFactory<TailTokenStream> for LowerCaser
+    where TailTokenStream: TokenStream {
+
+    type ResultTokenStream = LowerCaserTokenStream<TailTokenStream>;
+
+    fn transform(&self, token_stream: TailTokenStream) -> Self::ResultTokenStream {
+        LowerCaserTokenStream::wrap(token_stream)
+    }
+}
+
+pub struct LowerCaserTokenStream<TailTokenStream>
+    where TailTokenStream: TokenStream {
+    tail: TailTokenStream,
+}
+
+impl<TailTokenStream> TokenStream for LowerCaserTokenStream<TailTokenStream>
+    where TailTokenStream: TokenStream {
+
+    fn token(&self) -> &Token {
+        self.tail.token()
+    }
+
+    fn token_mut(&mut self) -> &mut Token {
+        self.tail.token_mut()
+    }
+
+    fn advance(&mut self) -> bool {
+        if self.tail.advance() {
+            self.tail.token_mut().term.make_ascii_lowercase();
+            return true;
+        }
+        else {
+            return false;
+        }
+    }
+}
+
+impl<TailTokenStream> LowerCaserTokenStream<TailTokenStream>
+    where TailTokenStream: TokenStream {
+
+
+    fn wrap(tail: TailTokenStream) -> LowerCaserTokenStream<TailTokenStream> {
+        LowerCaserTokenStream {
+            tail: tail,
+        }
+    }
+}
+
+
diff --git a/src/analyzer/mod.rs b/src/analyzer/mod.rs
index cfc20218d..1d2974f29 100644
--- a/src/analyzer/mod.rs
+++ b/src/analyzer/mod.rs
@@ -1,84 +1,54 @@
 extern crate regex;
 
-use std::str::Chars;
-use std::ascii::AsciiExt;
+mod analyzer;
+mod simple_tokenizer;
+mod lower_caser;
+mod remove_long;
+mod stemmer;
 
-pub struct TokenIter<'a> {
-    chars: Chars<'a>,
-    term_buffer: String,
+pub use self::analyzer::{Analyzer, Token, TokenFilterFactory, TokenStream};
+pub use self::simple_tokenizer::SimpleTokenizer;
+pub use self::remove_long::RemoveLongFilter;
+pub use self::lower_caser::LowerCaser;
+pub use self::stemmer::Stemmer;
+
+
+
+pub fn en_analyzer<'a>() -> impl Analyzer<'a> {
+    SimpleTokenizer
+        .filter(RemoveLongFilter::limit(20))
+        .filter(LowerCaser)
 }
 
-fn append_char_lowercase(c: char, term_buffer: &mut String) {
-    term_buffer.push(c.to_ascii_lowercase());
-}
+#[cfg(test)]
+mod test {
 
+    use super::{Analyzer, TokenStream, en_analyzer};
-pub trait StreamingIterator<'a, T> {
-    fn next(&'a mut self) -> Option<T>;
-}
-
-impl<'a, 'b> TokenIter<'b> {
-    fn consume_token(&'a mut self) -> Option<&'a str> {
-        for c in &mut self.chars {
-            if c.is_alphanumeric() {
-                append_char_lowercase(c, &mut self.term_buffer);
-            } else {
-                break;
-            }
-        }
-        Some(&self.term_buffer)
+    #[test]
+    fn test_tokenizer() {
+        let mut analyzer = en_analyzer();
+        let mut terms = analyzer.analyze("hello, happy tax payer!");
+        assert_eq!(terms.next().unwrap().term, "hello");
+        assert_eq!(terms.next().unwrap().term, "happy");
+        assert_eq!(terms.next().unwrap().term, "tax");
+        assert_eq!(terms.next().unwrap().term, "payer");
+        assert!(terms.next().is_none());
     }
-}
-
-impl<'a, 'b> StreamingIterator<'a, &'a str> for TokenIter<'b> {
-    #[inline]
-    fn next(&'a mut self) -> Option<&'a str> {
-        self.term_buffer.clear();
-        // skipping non-letter characters.
-        loop {
-            match self.chars.next() {
-                Some(c) => {
-                    if c.is_alphanumeric() {
-                        append_char_lowercase(c, &mut self.term_buffer);
-                        return self.consume_token();
-                    }
-                }
-                None => {
-                    return None;
-                }
-            }
-        }
+    #[test]
+    fn test_tokenizer_empty() {
+        let mut terms = en_analyzer().analyze("");
+        assert!(terms.next().is_none());
     }
-}
-
-pub struct SimpleTokenizer;
 
-impl SimpleTokenizer {
-    pub fn tokenize<'a>(&self, text: &'a str) -> TokenIter<'a> {
-        TokenIter {
-            term_buffer: String::new(),
-            chars: text.chars(),
-        }
+    #[test]
+    fn test_tokenizer_cjkchars() {
+        let mut terms = en_analyzer().analyze("hello,中国人民");
+        assert_eq!(terms.next().unwrap().term, "hello");
+        assert_eq!(terms.next().unwrap().term, "中国人民");
+        assert!(terms.next().is_none());
     }
+
 }
-
-#[test]
-fn test_tokenizer() {
-    let simple_tokenizer = SimpleTokenizer;
-    let mut term_reader = simple_tokenizer.tokenize("hello, happy tax payer!");
-    assert_eq!(term_reader.next().unwrap(), "hello");
-    assert_eq!(term_reader.next().unwrap(), "happy");
-    assert_eq!(term_reader.next().unwrap(), "tax");
-    assert_eq!(term_reader.next().unwrap(), "payer");
-    assert_eq!(term_reader.next(), None);
-}
-
-
-#[test]
-fn test_tokenizer_empty() {
-    let simple_tokenizer = SimpleTokenizer;
-    let mut term_reader = simple_tokenizer.tokenize("");
-    assert_eq!(term_reader.next(), None);
-}
diff --git a/src/analyzer/remove_long.rs b/src/analyzer/remove_long.rs
new file mode 100644
index 000000000..b4b4b4e0e
--- /dev/null
+++ b/src/analyzer/remove_long.rs
@@ -0,0 +1,74 @@
+use super::{TokenFilterFactory, TokenStream, Token};
+
+
+pub struct RemoveLongFilter {
+    length_limit: usize,
+}
+
+impl RemoveLongFilter {
+    // the limit is in bytes of the UTF-8 representation.
+    pub fn limit(length_limit: usize) -> RemoveLongFilter {
+        RemoveLongFilter {
+            length_limit: length_limit,
+        }
+    }
+}
+
+impl<TailTokenStream> RemoveLongFilterStream<TailTokenStream>
+    where TailTokenStream: TokenStream {
+
+    fn predicate(&self, token: &Token) -> bool {
+        token.term.len() < self.token_length_limit
+    }
+
+    fn wrap(token_length_limit: usize, tail: TailTokenStream) -> RemoveLongFilterStream<TailTokenStream> {
+        RemoveLongFilterStream {
+            token_length_limit: token_length_limit,
+            tail: tail,
+        }
+    }
+}
+
+
+impl<TailTokenStream> TokenFilterFactory<TailTokenStream> for RemoveLongFilter
+    where TailTokenStream: TokenStream {
+
+    type ResultTokenStream = RemoveLongFilterStream<TailTokenStream>;
+
+    fn transform(&self, token_stream: TailTokenStream) -> Self::ResultTokenStream {
+        RemoveLongFilterStream::wrap(self.length_limit, token_stream)
+    }
+}
+
+pub struct RemoveLongFilterStream<TailTokenStream>
+    where TailTokenStream: TokenStream {
+
+    token_length_limit: usize,
+    tail: TailTokenStream,
+}
+
+impl<TailTokenStream> TokenStream for RemoveLongFilterStream<TailTokenStream>
+    where TailTokenStream: TokenStream {
+
+    fn token(&self) -> &Token {
+        self.tail.token()
+    }
+
+    fn token_mut(&mut self) -> &mut Token {
+        self.tail.token_mut()
+    }
+
+    fn advance(&mut self) -> bool {
+        loop {
+            if self.tail.advance() {
+                if self.predicate(self.tail.token()) {
+                    return true;
+                }
+            }
+            else {
+                return false;
+            }
+        }
+    }
+
+}
\ No newline at end of file
diff --git a/src/analyzer/simple_tokenizer.rs b/src/analyzer/simple_tokenizer.rs
new file mode 100644
index 000000000..2d5b27907
--- /dev/null
+++ b/src/analyzer/simple_tokenizer.rs
@@ -0,0 +1,69 @@
+
+use std::str::CharIndices;
+use super::{Token, Analyzer, TokenStream};
+
+pub struct SimpleTokenizer;
+
+pub struct SimpleTokenStream<'a> {
+    text: &'a str,
+    chars: CharIndices<'a>,
+    token: Token,
+}
+
+impl<'a> Analyzer<'a> for SimpleTokenizer {
+
+    type TokenStreamImpl = SimpleTokenStream<'a>;
+
+    fn analyze(&mut self, text: &'a str) -> Self::TokenStreamImpl {
+        SimpleTokenStream {
+            text: text,
+            chars: text.char_indices(),
+            token: Token::default(),
+        }
+    }
+}
+
+impl<'a> SimpleTokenStream<'a> {
+
+    fn token_limit(&mut self) -> usize {
+        (&mut self.chars)
+            .filter(|&(_, ref c)| !c.is_alphanumeric())
+            .map(|(offset, _)| offset)
+            .next()
+            .unwrap_or(self.text.len())
+    }
+}
+
+impl<'a> TokenStream for SimpleTokenStream<'a> {
+
+    fn advance(&mut self) -> bool {
+        self.token.term.clear();
+        self.token.position += 1;
+
+        loop {
+            match self.chars.next() {
+                Some((offset_from, c)) => {
+                    if c.is_alphanumeric() {
+                        let offset_to = self.token_limit();
+                        self.token.offset_from = offset_from;
+                        self.token.offset_to = offset_to;
+                        self.token.term.push_str(&self.text[offset_from..offset_to]);
+                        return true;
+                    }
+                }
+                None => {
+                    return false;
+                }
+            }
+        }
+    }
+
+    fn token(&self) -> &Token {
+        &self.token
+    }
+
+    fn token_mut(&mut self) -> &mut Token {
+        &mut self.token
+    }
+
+}
\ No newline at end of file
diff --git a/src/analyzer/stemmer.rs b/src/analyzer/stemmer.rs
new file mode 100644
index 000000000..4988d8325
--- /dev/null
+++ b/src/analyzer/stemmer.rs
@@ -0,0 +1,69 @@
+use std::sync::Arc;
+use super::{TokenFilterFactory, TokenStream, Token};
+use rust_stemmers::{Algorithm, self};
+
+pub struct Stemmer {
+    stemmer: Arc<rust_stemmers::Stemmer>,
+}
+
+impl Stemmer {
+    pub fn new() -> Stemmer {
+        let inner_stemmer = rust_stemmers::Stemmer::create(Algorithm::English);
+        Stemmer {
+            stemmer: Arc::new(inner_stemmer),
+        }
+    }
+}
+
+impl<TailTokenStream> TokenFilterFactory<TailTokenStream> for Stemmer
+    where TailTokenStream: TokenStream {
+
+    type ResultTokenStream = StemmerTokenStream<TailTokenStream>;
+
+    fn transform(&self, token_stream: TailTokenStream) -> Self::ResultTokenStream {
+        StemmerTokenStream::wrap(self.stemmer.clone(), token_stream)
+    }
+}
+
+
+pub struct StemmerTokenStream<TailTokenStream>
+    where TailTokenStream: TokenStream {
+    tail: TailTokenStream,
+    stemmer: Arc<rust_stemmers::Stemmer>,
+}
+
+impl<TailTokenStream> TokenStream for StemmerTokenStream<TailTokenStream>
+    where TailTokenStream: TokenStream {
+
+    fn token(&self) -> &Token {
+        self.tail.token()
+    }
+
+    fn token_mut(&mut self) -> &mut Token {
+        self.tail.token_mut()
+    }
+
+    fn advance(&mut self) -> bool {
+        if self.tail.advance() {
+            // TODO remove allocation
+            let stemmed_str: String = self.stemmer.stem(&self.token().term).into_owned();
+            self.token_mut().term.clear();
+            self.token_mut().term.push_str(&stemmed_str);
+            true
+        }
+        else {
+            false
+        }
+    }
+}
+
+impl<TailTokenStream> StemmerTokenStream<TailTokenStream>
+    where TailTokenStream: TokenStream {
+
+    fn wrap(stemmer: Arc<rust_stemmers::Stemmer>, tail: TailTokenStream) -> StemmerTokenStream<TailTokenStream> {
+        StemmerTokenStream {
+            tail: tail,
+            stemmer: stemmer,
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/collector/facet_collector.rs b/src/collector/facet_collector.rs
index 2d760dfc6..983d1ffd6 100644
--- a/src/collector/facet_collector.rs
+++ b/src/collector/facet_collector.rs
@@ -103,7 +103,7 @@ mod tests {
         {
             // perform the query
             let mut facet_collectors = chain().push(&mut ffvf_i64).push(&mut ffvf_u64);
-            let query_parser = QueryParser::new(schema, vec![text_field]);
+            let mut query_parser = QueryParser::new(schema, vec![text_field]);
             let query = query_parser.parse_query("text:text").unwrap();
             query.search(&searcher, &mut facet_collectors).unwrap();
         }
diff --git a/src/lib.rs b/src/lib.rs
index 5d13c8299..592bc414a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -63,6 +63,7 @@ extern crate futures;
 extern crate futures_cpupool;
 extern crate owning_ref;
 extern crate stable_deref_trait;
+extern crate rust_stemmers;
 
 #[cfg(test)]
 extern crate env_logger;
@@ -98,7 +99,7 @@ mod compression;
 mod indexer;
 mod common;
 mod error;
-mod analyzer;
+pub mod analyzer;
 
 mod datastruct;
 pub mod termdict;
diff --git a/src/postings/postings_writer.rs b/src/postings/postings_writer.rs
index 772506bef..812490fff 100644
--- a/src/postings/postings_writer.rs
+++ b/src/postings/postings_writer.rs
@@ -7,7 +7,7 @@ use postings::Recorder;
 use analyzer::SimpleTokenizer;
 use Result;
 use schema::{Schema, Field};
-use analyzer::StreamingIterator;
+use analyzer::{TokenStream, Analyzer};
 use std::marker::PhantomData;
 use std::ops::DerefMut;
 use datastruct::stacker::{HashMap, Heap};
@@ -155,11 +155,11 @@ pub trait PostingsWriter {
         let mut term = unsafe { Term::with_capacity(100) };
         term.set_field(field);
         for field_value in field_values {
-            let mut tokens = SimpleTokenizer.tokenize(field_value.value().text());
+            let mut tokens = SimpleTokenizer.analyze(field_value.value().text());
             // right now num_tokens and pos are redundant, but it should
             // change when we get proper analyzers
             while let Some(token) = tokens.next() {
-                term.set_text(token);
+                term.set_text(&token.term);
                 self.suscribe(term_index, doc_id, pos, &term, heap);
                 pos += 1u32;
                 num_tokens += 1u32;
diff --git a/src/query/query_parser/query_parser.rs b/src/query/query_parser/query_parser.rs
index 0b6b43efe..8c714ec08 100644
--- a/src/query/query_parser/query_parser.rs
+++ b/src/query/query_parser/query_parser.rs
@@ -8,12 +8,11 @@ use query::Occur;
 use query::TermQuery;
 use postings::SegmentPostingsOption;
 use query::PhraseQuery;
-use analyzer::SimpleTokenizer;
-use analyzer::StreamingIterator;
+use analyzer::{SimpleTokenizer, TokenStream};
 use schema::{Term, FieldType};
 use std::str::FromStr;
 use std::num::ParseIntError;
-
+use analyzer::Analyzer;
 
 
 /// Possible error that may happen when parsing a query.
@@ -110,13 +109,13 @@ impl QueryParser {
     ///
     /// Implementing a lenient mode for this query parser is tracked
     /// in [Issue 5](https://github.com/fulmicoton/tantivy/issues/5)
-    pub fn parse_query(&self, query: &str) -> Result<Box<Query>, QueryParserError> {
+    pub fn parse_query(&mut self, query: &str) -> Result<Box<Query>, QueryParserError> {
         let logical_ast = self.parse_query_to_logical_ast(query)?;
         Ok(convert_to_query(logical_ast))
     }
 
     /// Parse the user query into an AST.
-    fn parse_query_to_logical_ast(&self, query: &str) -> Result<LogicalAST, QueryParserError> {
+    fn parse_query_to_logical_ast(&mut self, query: &str) -> Result<LogicalAST, QueryParserError> {
         let (user_input_ast, _remaining) = parse_to_ast(query)
             .map_err(|_| QueryParserError::SyntaxError)?;
         self.compute_logical_ast(user_input_ast)
@@ -128,7 +127,7 @@ impl QueryParser {
             .ok_or_else(|| QueryParserError::FieldDoesNotExist(String::from(field_name)))
     }
 
-    fn compute_logical_ast(&self,
+    fn compute_logical_ast(&mut self,
                            user_input_ast: UserInputAST)
                            -> Result<LogicalAST, QueryParserError> {
         let (occur, ast) = self.compute_logical_ast_with_occur(user_input_ast)?;
@@ -138,7 +137,7 @@ impl QueryParser {
         Ok(ast)
     }
 
-    fn compute_logical_ast_for_leaf(&self,
+    fn compute_logical_ast_for_leaf(&mut self,
                                     field: Field,
                                     phrase: &str)
                                     -> Result<Option<LogicalLiteral>, QueryParserError> {
@@ -163,9 +162,9 @@ impl QueryParser {
             FieldType::Str(ref str_options) => {
                 let mut terms: Vec<Term> = Vec::new();
                 if str_options.get_indexing_options().is_tokenized() {
-                    let mut token_iter = self.analyzer.tokenize(phrase);
+                    let mut token_iter = self.analyzer.analyze(phrase);
                     while let Some(token) = token_iter.next() {
-                        let term = Term::from_field_text(field, token);
+                        let term = Term::from_field_text(field, &token.term);
                         terms.push(term);
                     }
                 } else {
@@ -191,7 +190,7 @@ impl QueryParser {
         }
     }
 
-    fn compute_logical_ast_with_occur(&self,
+    fn compute_logical_ast_with_occur(&mut self,
                                       user_input_ast: UserInputAST)
                                       -> Result<(Occur, LogicalAST), QueryParserError> {
         match user_input_ast {
@@ -341,15 +340,15 @@ mod test {
     #[test]
     pub fn test_parse_query_simple() {
-        let query_parser = make_query_parser();
+        let mut query_parser = make_query_parser();
         assert!(query_parser.parse_query("toto").is_ok());
     }
 
     #[test]
     pub fn test_parse_nonindexed_field_yields_error() {
-        let query_parser = make_query_parser();
+        let mut query_parser = make_query_parser();
 
-        let is_not_indexed_err = |query: &str| {
+        let mut is_not_indexed_err = |query: &str| {
            let result: Result<Box<Query>, QueryParserError> = query_parser.parse_query(query);
            if let Err(QueryParserError::FieldNotIndexed(field_name)) = result {
                Some(field_name.clone())
            } else {
@@ -377,7 +376,7 @@ mod test {
 
     #[test]
     pub fn test_parse_query_ints() {
-        let query_parser = make_query_parser();
+        let mut query_parser = make_query_parser();
         assert!(query_parser.parse_query("signed:2324").is_ok());
         assert!(query_parser.parse_query("signed:\"22\"").is_ok());
         assert!(query_parser.parse_query("signed:\"-2234\"").is_ok());
diff --git a/src/query/query_parser/stemmer.rs b/src/query/query_parser/stemmer.rs
new file mode 100644
index 000000000..a1818950f
--- /dev/null
+++ b/src/query/query_parser/stemmer.rs
@@ -0,0 +1,44 @@
+use std::sync::Arc;
+use stemmer;
+
+
+pub struct StemmerTokenStream<TailTokenStream>
+    where TailTokenStream: TokenStream {
+    tail: TailTokenStream,
+    stemmer: Arc<stemmer::Stemmer>,
+}
+
+impl<TailTokenStream> TokenStream for StemmerTokenStream<TailTokenStream>
+    where TailTokenStream: TokenStream {
+
+    fn token(&self) -> &Token {
+        self.tail.token()
+    }
+
+    fn token_mut(&mut self) -> &mut Token {
+        self.tail.token_mut()
+    }
+
+    fn advance(&mut self) -> bool {
+        if self.tail.advance() {
+            // self.tail.token_mut().term.make_ascii_lowercase();
+            let new_str = self.stemmer.stem_str(&self.token().term);
+            true
+        }
+        else {
+            false
+        }
+    }
+
+}
+
+impl<TailTokenStream> StemmerTokenStream<TailTokenStream>
+    where TailTokenStream: TokenStream {
+
+    fn wrap(stemmer: Arc<stemmer::Stemmer>, tail: TailTokenStream) -> StemmerTokenStream<TailTokenStream> {
+        StemmerTokenStream {
+            tail: tail,
+            stemmer: stemmer,
+        }
+    }
+}
\ No newline at end of file
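
Usage sketch (not part of the patch): the new Analyzer trait is meant to be chained through filter(), as the en_analyzer() helper and the mod.rs tests above do. The snippet below mirrors that pattern and additionally plugs in the Stemmer filter from stemmer.rs; the tantivy::analyzer path assumes the module is public, as the lib.rs hunk makes it.

extern crate tantivy;

use tantivy::analyzer::{Analyzer, LowerCaser, RemoveLongFilter, SimpleTokenizer, Stemmer, TokenStream};

fn main() {
    // Build a pipeline equivalent to en_analyzer(), with an extra stemming stage.
    let mut analyzer = SimpleTokenizer
        .filter(RemoveLongFilter::limit(20))
        .filter(LowerCaser)
        .filter(Stemmer::new());

    // analyze() returns a TokenStream; next() yields a borrowed Token per term.
    let mut stream = analyzer.analyze("The quick foxes are running");
    while let Some(token) = stream.next() {
        println!("{} [{}..{}]", token.term, token.offset_from, token.offset_to);
    }
}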