From 2cab111f99043c1c4f4ed74ce751ea4b6d14dc01 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Massot?=
Date: Sun, 25 Jun 2023 12:36:47 +0200
Subject: [PATCH] Refactor token filter.

---
 benches/analyzer.rs                   | 25 +++++++-
 src/tokenizer/alphanum_only.rs        | 21 ++-----
 src/tokenizer/ascii_folding_filter.rs | 38 ++++--------
 src/tokenizer/lower_caser.rs          | 36 +++---------
 src/tokenizer/remove_long.rs          | 25 ++------
 src/tokenizer/split_compound_words.rs | 25 ++------
 src/tokenizer/stemmer.rs              | 25 ++------
 src/tokenizer/stop_word_filter/mod.rs | 25 ++------
 src/tokenizer/tokenizer.rs            | 83 ++++++++++++++++-----------
 tokenizer-api/src/lib.rs              | 26 ++++++++-
 10 files changed, 134 insertions(+), 195 deletions(-)

diff --git a/benches/analyzer.rs b/benches/analyzer.rs
index 7a96fa119..6dfc3117a 100644
--- a/benches/analyzer.rs
+++ b/benches/analyzer.rs
@@ -1,5 +1,7 @@
 use criterion::{criterion_group, criterion_main, Criterion};
-use tantivy::tokenizer::TokenizerManager;
+use tantivy::tokenizer::{
+    BoxTokenFilter, LowerCaser, RemoveLongFilter, SimpleTokenizer, TextAnalyzer, TokenizerManager,
+};
 
 const ALICE_TXT: &str = include_str!("alice.txt");
 
@@ -16,7 +18,26 @@ pub fn criterion_benchmark(c: &mut Criterion) {
             assert_eq!(word_count, 30_731);
         })
     });
+    let token_filters = vec![
+        BoxTokenFilter::from(RemoveLongFilter::limit(40)),
+        BoxTokenFilter::from(LowerCaser),
+    ];
+    let mut dynamic_analyzer = TextAnalyzer::new(SimpleTokenizer::default(), token_filters);
+    c.bench_function("default-dynamic-tokenize-alice", |b| {
+        b.iter(|| {
+            let mut word_count = 0;
+            let mut token_stream = dynamic_analyzer.token_stream(ALICE_TXT);
+            while token_stream.advance() {
+                word_count += 1;
+            }
+            assert_eq!(word_count, 30_731);
+        })
+    });
 }
 
-criterion_group!(benches, criterion_benchmark);
+criterion_group! {
+    name = benches;
+    config = Criterion::default().sample_size(200);
+    targets = criterion_benchmark
+}
 criterion_main!(benches);
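The new `default-dynamic-tokenize-alice` benchmark exercises the runtime-composed path this patch introduces: `TextAnalyzer::new` takes the base tokenizer plus a `Vec<BoxTokenFilter>` assembled at runtime. A minimal sketch of the same pattern outside criterion, using only types touched by this patch (the input text and the filter choices are illustrative):

```rust
use tantivy::tokenizer::{
    BoxTokenFilter, LowerCaser, RemoveLongFilter, SimpleTokenizer, TextAnalyzer, TokenStream,
};

fn main() {
    // The filter chain is a plain Vec, so it can be assembled from
    // configuration instead of being fixed at compile time.
    let token_filters = vec![
        BoxTokenFilter::from(RemoveLongFilter::limit(40)),
        BoxTokenFilter::from(LowerCaser),
    ];
    let mut analyzer = TextAnalyzer::new(SimpleTokenizer::default(), token_filters);

    // `token_stream` returns a boxed stream; consume it with `advance`/`token`.
    let mut token_stream = analyzer.token_stream("Down the Rabbit-Hole");
    while token_stream.advance() {
        println!("{}", token_stream.token().text);
    }
}
```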
diff --git a/src/tokenizer/alphanum_only.rs b/src/tokenizer/alphanum_only.rs
index b40731fd3..592575d4d 100644
--- a/src/tokenizer/alphanum_only.rs
+++ b/src/tokenizer/alphanum_only.rs
@@ -21,7 +21,7 @@
 //! // the "emoji" is dropped because its not an alphanum
 //! assert!(stream.next().is_none());
 //! ```
-use super::{Token, TokenFilter, TokenStream, Tokenizer};
+use super::{Token, TokenFilter, TokenStream};
 
 /// `TokenFilter` that removes all tokens that contain non
 /// ascii alphanumeric characters.
@@ -39,23 +39,10 @@ impl<T> AlphaNumOnlyFilterStream<T> {
 }
 
 impl TokenFilter for AlphaNumOnlyFilter {
-    type Tokenizer<T: Tokenizer> = AlphaNumOnlyFilterWrapper<T>;
+    type OutputTokenStream<T: TokenStream> = AlphaNumOnlyFilterStream<T>;
 
-    fn transform<T: Tokenizer>(self, tokenizer: T) -> AlphaNumOnlyFilterWrapper<T> {
-        AlphaNumOnlyFilterWrapper(tokenizer)
-    }
-}
-
-#[derive(Clone)]
-pub struct AlphaNumOnlyFilterWrapper<T>(T);
-
-impl<T: Tokenizer> Tokenizer for AlphaNumOnlyFilterWrapper<T> {
-    type TokenStream<'a> = AlphaNumOnlyFilterStream<T::TokenStream<'a>>;
-
-    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
-        AlphaNumOnlyFilterStream {
-            tail: self.0.token_stream(text),
-        }
+    fn filter<T: TokenStream>(&self, token_stream: T) -> Self::OutputTokenStream<T> {
+        AlphaNumOnlyFilterStream { tail: token_stream }
     }
 }
 
diff --git a/src/tokenizer/ascii_folding_filter.rs b/src/tokenizer/ascii_folding_filter.rs
index da8039e17..981d09e27 100644
--- a/src/tokenizer/ascii_folding_filter.rs
+++ b/src/tokenizer/ascii_folding_filter.rs
@@ -1,6 +1,6 @@
 use std::mem;
 
-use super::{Token, TokenFilter, TokenStream, Tokenizer};
+use super::{Token, TokenFilter, TokenStream};
 
 /// This class converts alphabetic, numeric, and symbolic Unicode characters
 /// which are not in the first 127 ASCII characters (the "Basic Latin" Unicode
@@ -9,48 +9,30 @@ use super::{Token, TokenFilter, TokenStream, Tokenizer};
 pub struct AsciiFoldingFilter;
 
 impl TokenFilter for AsciiFoldingFilter {
-    type Tokenizer<T: Tokenizer> = AsciiFoldingFilterWrapper<T>;
+    type OutputTokenStream<T: TokenStream> = AsciiFoldingFilterTokenStream<T>;
 
-    fn transform<T: Tokenizer>(self, tokenizer: T) -> AsciiFoldingFilterWrapper<T> {
-        AsciiFoldingFilterWrapper {
-            tokenizer,
-            buffer: String::new(),
-        }
-    }
-}
-
-#[derive(Clone)]
-pub struct AsciiFoldingFilterWrapper<T> {
-    tokenizer: T,
-    buffer: String,
-}
-
-impl<T: Tokenizer> Tokenizer for AsciiFoldingFilterWrapper<T> {
-    type TokenStream<'a> = AsciiFoldingFilterTokenStream<'a, T::TokenStream<'a>>;
-
-    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
-        self.buffer.clear();
+    fn filter<T: TokenStream>(&self, token_stream: T) -> Self::OutputTokenStream<T> {
         AsciiFoldingFilterTokenStream {
-            buffer: &mut self.buffer,
-            tail: self.tokenizer.token_stream(text),
+            buffer: String::new(),
+            tail: token_stream,
         }
     }
 }
 
-pub struct AsciiFoldingFilterTokenStream<'a, T> {
-    buffer: &'a mut String,
+pub struct AsciiFoldingFilterTokenStream<T> {
+    buffer: String,
     tail: T,
 }
 
-impl<'a, T: TokenStream> TokenStream for AsciiFoldingFilterTokenStream<'a, T> {
+impl<T: TokenStream> TokenStream for AsciiFoldingFilterTokenStream<T> {
     fn advance(&mut self) -> bool {
         if !self.tail.advance() {
             return false;
         }
         if !self.token_mut().text.is_ascii() {
             // ignore its already ascii
-            to_ascii(&self.tail.token().text, self.buffer);
-            mem::swap(&mut self.tail.token_mut().text, self.buffer);
+            to_ascii(&self.tail.token().text, &mut self.buffer);
+            mem::swap(&mut self.tail.token_mut().text, &mut self.buffer);
         }
         true
     }
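The two files above show the pattern the rest of the patch repeats: the per-filter wrapper `Tokenizer` types disappear, and each filter only implements `filter`, which wraps the upstream `TokenStream` directly. As a sketch of what a downstream custom filter would look like against the new trait, mirroring the `AlphaNumOnlyFilter` implementation above (`RemoveEmptyFilter` is a made-up example, not part of this patch):

```rust
use tokenizer_api::{Token, TokenFilter, TokenStream};

/// Hypothetical filter that drops tokens whose text is empty.
#[derive(Clone)]
pub struct RemoveEmptyFilter;

pub struct RemoveEmptyFilterStream<T> {
    tail: T,
}

impl TokenFilter for RemoveEmptyFilter {
    // Same shape as the filters in this patch: the output stream is a GAT
    // parametrized by the upstream stream type.
    type OutputTokenStream<T: TokenStream> = RemoveEmptyFilterStream<T>;

    fn filter<T: TokenStream>(&self, token_stream: T) -> Self::OutputTokenStream<T> {
        RemoveEmptyFilterStream { tail: token_stream }
    }
}

impl<T: TokenStream> TokenStream for RemoveEmptyFilterStream<T> {
    fn advance(&mut self) -> bool {
        // Skip upstream tokens until a non-empty one is found.
        while self.tail.advance() {
            if !self.tail.token().text.is_empty() {
                return true;
            }
        }
        false
    }

    fn token(&self) -> &Token {
        self.tail.token()
    }

    fn token_mut(&mut self) -> &mut Token {
        self.tail.token_mut()
    }
}
```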
diff --git a/src/tokenizer/lower_caser.rs b/src/tokenizer/lower_caser.rs
index 56792ba82..ab3b3533c 100644
--- a/src/tokenizer/lower_caser.rs
+++ b/src/tokenizer/lower_caser.rs
@@ -1,42 +1,24 @@
 use std::mem;
 
-use super::{Token, TokenFilter, TokenStream, Tokenizer};
+use super::{Token, TokenFilter, TokenStream};
 
 /// Token filter that lowercase terms.
 #[derive(Clone)]
 pub struct LowerCaser;
 
 impl TokenFilter for LowerCaser {
-    type Tokenizer<T: Tokenizer> = LowerCaserFilter<T>;
+    type OutputTokenStream<T: TokenStream> = LowerCaserTokenStream<T>;
 
-    fn transform<T: Tokenizer>(self, tokenizer: T) -> Self::Tokenizer<T> {
-        LowerCaserFilter {
-            tokenizer,
+    fn filter<T: TokenStream>(&self, token_stream: T) -> Self::OutputTokenStream<T> {
+        LowerCaserTokenStream {
+            tail: token_stream,
             buffer: String::new(),
         }
     }
 }
 
-#[derive(Clone)]
-pub struct LowerCaserFilter<T> {
-    tokenizer: T,
+pub struct LowerCaserTokenStream<T> {
     buffer: String,
-}
-
-impl<T: Tokenizer> Tokenizer for LowerCaserFilter<T> {
-    type TokenStream<'a> = LowerCaserTokenStream<'a, T::TokenStream<'a>>;
-
-    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
-        self.buffer.clear();
-        LowerCaserTokenStream {
-            tail: self.tokenizer.token_stream(text),
-            buffer: &mut self.buffer,
-        }
-    }
-}
-
-pub struct LowerCaserTokenStream<'a, T> {
-    buffer: &'a mut String,
     tail: T,
 }
 
@@ -51,7 +33,7 @@ fn to_lowercase_unicode(text: &str, output: &mut String) {
     }
 }
 
-impl<'a, T: TokenStream> TokenStream for LowerCaserTokenStream<'a, T> {
+impl<T: TokenStream> TokenStream for LowerCaserTokenStream<T> {
     fn advance(&mut self) -> bool {
         if !self.tail.advance() {
             return false;
@@ -60,8 +42,8 @@ impl<'a, T: TokenStream> TokenStream for LowerCaserTokenStream<'a, T> {
             // fast track for ascii.
             self.token_mut().text.make_ascii_lowercase();
         } else {
-            to_lowercase_unicode(&self.tail.token().text, self.buffer);
-            mem::swap(&mut self.tail.token_mut().text, self.buffer);
+            to_lowercase_unicode(&self.tail.token().text, &mut self.buffer);
+            mem::swap(&mut self.tail.token_mut().text, &mut self.buffer);
         }
         true
     }
diff --git a/src/tokenizer/remove_long.rs b/src/tokenizer/remove_long.rs
index 78f3e731a..5342e89e9 100644
--- a/src/tokenizer/remove_long.rs
+++ b/src/tokenizer/remove_long.rs
@@ -12,7 +12,7 @@
 //! assert_eq!(stream.next().unwrap().text, "nice");
 //! assert!(stream.next().is_none());
 //! ```
-use super::{Token, TokenFilter, TokenStream, Tokenizer};
+use super::{Token, TokenFilter, TokenStream};
 
 /// `RemoveLongFilter` removes tokens that are longer
 /// than a given number of bytes (in UTF-8 representation).
@@ -38,29 +38,12 @@ impl<T> RemoveLongFilterStream<T> {
 }
 
 impl TokenFilter for RemoveLongFilter {
-    type Tokenizer<T: Tokenizer> = RemoveLongFilterWrapper<T>;
+    type OutputTokenStream<T: TokenStream> = RemoveLongFilterStream<T>;
 
-    fn transform<T: Tokenizer>(self, tokenizer: T) -> RemoveLongFilterWrapper<T> {
-        RemoveLongFilterWrapper {
-            length_limit: self.length_limit,
-            inner: tokenizer,
-        }
-    }
-}
-
-#[derive(Clone)]
-pub struct RemoveLongFilterWrapper<T> {
-    length_limit: usize,
-    inner: T,
-}
-
-impl<T: Tokenizer> Tokenizer for RemoveLongFilterWrapper<T> {
-    type TokenStream<'a> = RemoveLongFilterStream<T::TokenStream<'a>>;
-
-    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
+    fn filter<T: TokenStream>(&self, token_stream: T) -> Self::OutputTokenStream<T> {
         RemoveLongFilterStream {
             token_length_limit: self.length_limit,
-            tail: self.inner.token_stream(text),
+            tail: token_stream,
         }
     }
 }
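One behavioral consequence of the `lower_caser.rs` and `ascii_folding_filter.rs` changes is that the scratch `String` now lives in the token stream itself instead of being borrowed from the wrapper tokenizer, so each `filter` call allocates a fresh buffer but all the lifetime plumbing disappears. Because `filter` operates on plain `TokenStream`s, filters can also be applied directly to a tokenizer's stream without a `TextAnalyzer`; a sketch of that usage (the function name, limit, and text are illustrative):

```rust
use tantivy::tokenizer::{
    LowerCaser, RemoveLongFilter, SimpleTokenizer, TokenFilter, TokenStream, Tokenizer,
};

/// Returns the lowercased tokens of `text` that are at most 10 bytes long.
fn lowercased_short_tokens(text: &str) -> Vec<String> {
    let mut tokenizer = SimpleTokenizer::default();
    // Each filter wraps the stream of the previous stage.
    let short = RemoveLongFilter::limit(10).filter(tokenizer.token_stream(text));
    let mut stream = LowerCaser.filter(short);

    let mut tokens = Vec::new();
    while stream.advance() {
        tokens.push(stream.token().text.clone());
    }
    tokens
}
```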
diff --git a/src/tokenizer/split_compound_words.rs b/src/tokenizer/split_compound_words.rs
index bcde161cc..678a204fc 100644
--- a/src/tokenizer/split_compound_words.rs
+++ b/src/tokenizer/split_compound_words.rs
@@ -1,6 +1,6 @@
 use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
 
-use super::{Token, TokenFilter, TokenStream, Tokenizer};
+use super::{Token, TokenFilter, TokenStream};
 
 /// A [`TokenFilter`] which splits compound words into their parts
 /// based on a given dictionary.
@@ -80,29 +80,12 @@ impl SplitCompoundWords {
 }
 
 impl TokenFilter for SplitCompoundWords {
-    type Tokenizer<T: Tokenizer> = SplitCompoundWordsFilter<T>;
+    type OutputTokenStream<T: TokenStream> = SplitCompoundWordsTokenStream<T>;
 
-    fn transform<T: Tokenizer>(self, tokenizer: T) -> SplitCompoundWordsFilter<T> {
-        SplitCompoundWordsFilter {
-            dict: self.dict,
-            inner: tokenizer,
-        }
-    }
-}
-
-#[derive(Clone)]
-pub struct SplitCompoundWordsFilter<T> {
-    dict: AhoCorasick,
-    inner: T,
-}
-
-impl<T: Tokenizer> Tokenizer for SplitCompoundWordsFilter<T> {
-    type TokenStream<'a> = SplitCompoundWordsTokenStream<T::TokenStream<'a>>;
-
-    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
+    fn filter<T: TokenStream>(&self, token_stream: T) -> Self::OutputTokenStream<T> {
         SplitCompoundWordsTokenStream {
             dict: self.dict.clone(),
-            tail: self.inner.token_stream(text),
+            tail: token_stream,
             cuts: Vec::new(),
             parts: Vec::new(),
         }
diff --git a/src/tokenizer/stemmer.rs b/src/tokenizer/stemmer.rs
index 4c43b609a..8d7e68776 100644
--- a/src/tokenizer/stemmer.rs
+++ b/src/tokenizer/stemmer.rs
@@ -4,7 +4,7 @@ use std::mem;
 use rust_stemmers::{self, Algorithm};
 use serde::{Deserialize, Serialize};
 
-use super::{Token, TokenFilter, TokenStream, Tokenizer};
+use super::{Token, TokenFilter, TokenStream};
 
 /// Available stemmer languages.
 #[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Copy, Clone)]
@@ -81,29 +81,12 @@ impl Default for Stemmer {
 }
 
 impl TokenFilter for Stemmer {
-    type Tokenizer<T: Tokenizer> = StemmerFilter<T>;
+    type OutputTokenStream<T: TokenStream> = StemmerTokenStream<T>;
 
-    fn transform<T: Tokenizer>(self, tokenizer: T) -> StemmerFilter<T> {
-        StemmerFilter {
-            stemmer_algorithm: self.stemmer_algorithm,
-            inner: tokenizer,
-        }
-    }
-}
-
-#[derive(Clone)]
-pub struct StemmerFilter<T> {
-    stemmer_algorithm: Algorithm,
-    inner: T,
-}
-
-impl<T: Tokenizer> Tokenizer for StemmerFilter<T> {
-    type TokenStream<'a> = StemmerTokenStream<T::TokenStream<'a>>;
-
-    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
+    fn filter<T: TokenStream>(&self, token_stream: T) -> Self::OutputTokenStream<T> {
         let stemmer = rust_stemmers::Stemmer::create(self.stemmer_algorithm);
         StemmerTokenStream {
-            tail: self.inner.token_stream(text),
+            tail: token_stream,
             stemmer,
             buffer: String::new(),
         }
diff --git a/src/tokenizer/stop_word_filter/mod.rs b/src/tokenizer/stop_word_filter/mod.rs
index 3217af716..bd5ee6425 100644
--- a/src/tokenizer/stop_word_filter/mod.rs
+++ b/src/tokenizer/stop_word_filter/mod.rs
@@ -21,7 +21,7 @@ use rustc_hash::FxHashSet;
 
 #[cfg(feature = "stopwords")]
 use super::Language;
-use super::{Token, TokenFilter, TokenStream, Tokenizer};
+use super::{Token, TokenFilter, TokenStream};
 
 /// `TokenFilter` that removes stop words from a token stream
 #[derive(Clone)]
@@ -72,29 +72,12 @@ impl StopWordFilter {
 }
 
 impl TokenFilter for StopWordFilter {
-    type Tokenizer<T: Tokenizer> = StopWordFilterWrapper<T>;
+    type OutputTokenStream<T: TokenStream> = StopWordFilterStream<T>;
 
-    fn transform<T: Tokenizer>(self, tokenizer: T) -> StopWordFilterWrapper<T> {
-        StopWordFilterWrapper {
-            words: self.words,
-            inner: tokenizer,
-        }
-    }
-}
-
-#[derive(Clone)]
-pub struct StopWordFilterWrapper<T> {
-    words: Arc<FxHashSet<String>>,
-    inner: T,
-}
-
-impl<T: Tokenizer> Tokenizer for StopWordFilterWrapper<T> {
-    type TokenStream<'a> = StopWordFilterStream<T::TokenStream<'a>>;
-
-    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
+    fn filter<T: TokenStream>(&self, token_stream: T) -> Self::OutputTokenStream<T> {
         StopWordFilterStream {
             words: self.words.clone(),
-            tail: self.inner.token_stream(text),
+            tail: token_stream,
         }
     }
 }
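The filter-by-filter changes above are mechanical: each `transform` plus wrapper-`Tokenizer` pair collapses into a single `filter` that wraps the incoming stream. Statically composed analyzers keep working because `transform` survives as a default method on the trait (see the `tokenizer-api` hunk below), so the existing builder path is unchanged for users. A sketch of that path, based on the doc examples already present in `tokenizer.rs` (the input text is illustrative):

```rust
use tantivy::tokenizer::{
    LowerCaser, RemoveLongFilter, SimpleTokenizer, Stemmer, TextAnalyzer, TokenStream,
};

fn main() {
    // Each `.filter(..)` call goes through the default `TokenFilter::transform`,
    // wrapping the tokenizer in a `FilteredTokenizer` at compile time.
    let mut en_stem = TextAnalyzer::builder(SimpleTokenizer::default())
        .filter(RemoveLongFilter::limit(40))
        .filter(LowerCaser)
        .filter(Stemmer::default())
        .build();

    let mut stream = en_stem.token_stream("Detailed documentation about tokenizers");
    while stream.advance() {
        println!("{}", stream.token().text);
    }
}
```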
diff --git a/src/tokenizer/tokenizer.rs b/src/tokenizer/tokenizer.rs
index a79802291..a185348b3 100644
--- a/src/tokenizer/tokenizer.rs
+++ b/src/tokenizer/tokenizer.rs
@@ -1,14 +1,14 @@
 use dyn_clone::DynClone;
 /// The tokenizer module contains all of the tools used to process
 /// text in `tantivy`.
-use tokenizer_api::{TokenFilter, TokenStream, Tokenizer};
+use tokenizer_api::{FilteredTokenizer, TokenFilter, TokenStream, Tokenizer};
 
 use crate::tokenizer::empty_tokenizer::EmptyTokenizer;
 
 /// `TextAnalyzer` tokenizes an input text into tokens and modifies the resulting `TokenStream`.
-#[derive(Clone)]
 pub struct TextAnalyzer {
     tokenizer: Box<dyn BoxableTokenizer>,
+    token_filters: Vec<BoxTokenFilter>,
 }
 
 /// A boxable `Tokenizer`, with its `TokenStream` type erased.
@@ -25,32 +25,30 @@ impl<T: Tokenizer> BoxableTokenizer for T {
 
 dyn_clone::clone_trait_object!(BoxableTokenizer);
 
-/// A boxed `BoxableTokenizer` which is a `Tokenizer` with its `TokenStream` type erased.
-#[derive(Clone)]
-struct BoxTokenizer(Box<dyn BoxableTokenizer>);
-
-impl Tokenizer for BoxTokenizer {
-    type TokenStream<'a> = Box<dyn TokenStream + 'a>;
-
-    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
-        self.0.box_token_stream(text).into()
-    }
-}
-
 /// A boxable `TokenFilter`, with its `Tokenizer` type erased.
-trait BoxableTokenFilter: 'static + Send + Sync {
-    /// Wraps a `BoxedTokenizer` and returns a new one.
-    fn box_transform(&self, tokenizer: BoxTokenizer) -> BoxTokenizer;
+trait BoxableTokenFilter: 'static + Send + Sync + DynClone {
+    /// Transforms a boxed token stream into a new one.
+    fn box_transform<'a>(
+        &self,
+        token_stream: Box<dyn TokenStream + 'a>,
+    ) -> Box<dyn TokenStream + 'a>;
 }
 
 impl<T: TokenFilter> BoxableTokenFilter for T {
-    fn box_transform(&self, tokenizer: BoxTokenizer) -> BoxTokenizer {
-        let tokenizer = self.clone().transform(tokenizer);
-        BoxTokenizer(Box::new(tokenizer))
+    fn box_transform<'a>(
+        &self,
+        token_stream: Box<dyn TokenStream + 'a>,
+    ) -> Box<dyn TokenStream + 'a> {
+        Box::new(self.clone().filter(token_stream))
     }
 }
 
-/// A boxed `BoxableTokenFilter` which is a `TokenFilter` with its `Tokenizer` type erased.
+dyn_clone::clone_trait_object!(BoxableTokenFilter);
+
+/// Simple wrapper of `Box<dyn BoxableTokenFilter>`.
+///
+/// See [`TokenFilter`] for more information.
+#[derive(Clone)]
 pub struct BoxTokenFilter(Box<dyn BoxableTokenFilter>);
 
 impl<T: TokenFilter> From<T> for BoxTokenFilter {
@@ -59,6 +57,19 @@ impl<T: TokenFilter> From<T> for BoxTokenFilter {
     }
 }
 
+impl Clone for TextAnalyzer {
+    fn clone(&self) -> Self {
+        TextAnalyzer {
+            tokenizer: self.tokenizer.clone(),
+            token_filters: self
+                .token_filters
+                .iter()
+                .map(|token_filter| token_filter.clone())
+                .collect(),
+        }
+    }
+}
+
 impl TextAnalyzer {
     /// Builds a new `TextAnalyzer` given a tokenizer and a vector of `BoxTokenFilter`.
     ///
@@ -71,7 +82,7 @@ impl TextAnalyzer {
     /// ```rust
     /// use tantivy::tokenizer::*;
     ///
-    /// let en_stem = TextAnalyzer::build(
+    /// let en_stem = TextAnalyzer::new(
    ///     SimpleTokenizer::default(),
     ///     vec![
     ///         BoxTokenFilter::from(RemoveLongFilter::limit(40)),
@@ -79,27 +90,25 @@ impl TextAnalyzer {
     ///         BoxTokenFilter::from(Stemmer::default()),
     ///     ]);
     /// ```
-    pub fn build<T: Tokenizer>(
-        tokenizer: T,
-        boxed_token_filters: Vec<BoxTokenFilter>,
-    ) -> TextAnalyzer {
-        let mut boxed_tokenizer = BoxTokenizer(Box::new(tokenizer));
-        for filter in boxed_token_filters.into_iter() {
-            boxed_tokenizer = filter.0.box_transform(boxed_tokenizer);
-        }
+    pub fn new<T: Tokenizer>(tokenizer: T, token_filters: Vec<BoxTokenFilter>) -> TextAnalyzer {
         TextAnalyzer {
-            tokenizer: boxed_tokenizer.0,
+            tokenizer: Box::new(tokenizer),
+            token_filters,
         }
     }
 
-    /// Create a new TextAnalyzerBuilder
+    /// Create a new TextAnalyzerBuilder.
     pub fn builder<T: Tokenizer>(tokenizer: T) -> TextAnalyzerBuilder<T> {
         TextAnalyzerBuilder { tokenizer }
     }
 
     /// Creates a token stream for a given `str`.
     pub fn token_stream<'a>(&'a mut self, text: &'a str) -> Box<dyn TokenStream + 'a> {
-        self.tokenizer.box_token_stream(text)
+        let mut token_stream = self.tokenizer.box_token_stream(text);
+        for token_filter in &self.token_filters {
+            token_stream = token_filter.0.box_transform(token_stream);
+        }
+        token_stream
     }
 }
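With the tokenizer and the filter vector stored separately, `token_stream` rebuilds the boxed filter chain on every call, and `TextAnalyzer` needs the hand-written `Clone` because `Box<dyn BoxableTokenizer>` and `BoxTokenFilter` are cloned through `DynClone`. One practical consequence is that a configured analyzer can still be cloned cheaply per worker; a sketch of that usage (the scoped-thread setup and function name are illustrative additions, not part of the patch):

```rust
use tantivy::tokenizer::{BoxTokenFilter, LowerCaser, SimpleTokenizer, TextAnalyzer, TokenStream};

/// Counts tokens across several documents, one scoped thread per document.
fn count_tokens(documents: &[String]) -> usize {
    let analyzer = TextAnalyzer::new(
        SimpleTokenizer::default(),
        vec![BoxTokenFilter::from(LowerCaser)],
    );
    std::thread::scope(|scope| {
        let handles: Vec<_> = documents
            .iter()
            .map(|doc| {
                // `token_stream` takes `&mut self`, so each thread gets its own clone.
                let mut analyzer = analyzer.clone();
                scope.spawn(move || {
                    let mut stream = analyzer.token_stream(doc);
                    let mut count = 0usize;
                    while stream.advance() {
                        count += 1;
                    }
                    count
                })
            })
            .collect();
        handles.into_iter().map(|handle| handle.join().unwrap()).sum()
    })
}
```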
@@ -134,7 +143,10 @@ impl<T: Tokenizer> TextAnalyzerBuilder<T> {
     ///     .filter(Stemmer::default())
     ///     .build();
     /// ```
-    pub fn filter<F: TokenFilter>(self, token_filter: F) -> TextAnalyzerBuilder<F::Tokenizer<T>> {
+    pub fn filter<F: TokenFilter>(
+        self,
+        token_filter: F,
+    ) -> TextAnalyzerBuilder<FilteredTokenizer<T, F>> {
         TextAnalyzerBuilder {
             tokenizer: token_filter.transform(self.tokenizer),
         }
@@ -144,6 +156,7 @@ impl<T: Tokenizer> TextAnalyzerBuilder<T> {
     pub fn build(self) -> TextAnalyzer {
         TextAnalyzer {
             tokenizer: Box::new(self.tokenizer),
+            token_filters: Vec::new(),
         }
     }
 }
@@ -168,7 +181,7 @@ mod tests {
 
     #[test]
     fn test_text_analyzer_with_filters_boxed() {
-        let mut analyzer = TextAnalyzer::build(
+        let mut analyzer = TextAnalyzer::new(
             WhitespaceTokenizer::default(),
             vec![
                 BoxTokenFilter::from(AlphaNumOnlyFilter),
diff --git a/tokenizer-api/src/lib.rs b/tokenizer-api/src/lib.rs
index 93defac11..179cbe39e 100644
--- a/tokenizer-api/src/lib.rs
+++ b/tokenizer-api/src/lib.rs
@@ -115,9 +115,31 @@ pub trait TokenStream {
 pub trait TokenFilter: 'static + Send + Sync + Clone {
-    /// The Tokenizer type returned by this filter, typically parametrized by the underlying
-    /// Tokenizer.
-    type Tokenizer<T: Tokenizer>: Tokenizer;
+    /// The `TokenStream` type returned by this filter, typically parametrized by the underlying
+    /// `TokenStream`.
+    type OutputTokenStream<T: TokenStream>: TokenStream;
+    /// Filters a token stream and returns a new one.
+    fn filter<T: TokenStream>(&self, token_stream: T) -> Self::OutputTokenStream<T>;
     /// Wraps a Tokenizer and returns a new one.
-    fn transform<T: Tokenizer>(self, tokenizer: T) -> Self::Tokenizer<T>;
+    fn transform<T: Tokenizer>(self, tokenizer: T) -> FilteredTokenizer<T, Self> {
+        FilteredTokenizer {
+            tokenizer,
+            token_filter: self,
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct FilteredTokenizer<T, F> {
+    tokenizer: T,
+    token_filter: F,
+}
+
+impl<T: Tokenizer, F: TokenFilter> Tokenizer for FilteredTokenizer<T, F> {
+    type TokenStream<'a> = F::OutputTokenStream<T::TokenStream<'a>>;
+
+    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
+        let token_stream = self.tokenizer.token_stream(text);
+        self.token_filter.filter(token_stream)
+    }
 }
 
 #[cfg(test)]
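Because `transform` now has a default body that returns `FilteredTokenizer`, a filter defined purely through `filter` also composes statically: the wrapped value is itself a `Tokenizer`. A small end-to-end sketch against the API introduced above (the sample text and the expected tokens are illustrative and assume `SimpleTokenizer`'s usual splitting on non-alphanumeric characters):

```rust
use tantivy::tokenizer::{LowerCaser, SimpleTokenizer, TokenFilter, TokenStream, Tokenizer};

fn main() {
    // `transform` wraps the tokenizer in a `FilteredTokenizer`, which is a
    // `Tokenizer` in its own right, so no `TextAnalyzer` is required here.
    let mut tokenizer = LowerCaser.transform(SimpleTokenizer::default());

    let mut stream = tokenizer.token_stream("Hello, Happy Tax Payer!");
    let mut tokens = Vec::new();
    while stream.advance() {
        tokens.push(stream.token().text.clone());
    }
    assert_eq!(tokens, vec!["hello", "happy", "tax", "payer"]);
}
```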