diff --git a/tokenizer-api/src/lib.rs b/tokenizer-api/src/lib.rs
index 93f3a5714..7a81ff9d3 100644
--- a/tokenizer-api/src/lib.rs
+++ b/tokenizer-api/src/lib.rs
@@ -157,6 +157,87 @@ pub trait TokenFilter: 'static + Send + Sync {
     fn transform<T: Tokenizer>(self, tokenizer: T) -> Self::Tokenizer<T>;
 }
 
+/// An optional [`TokenFilter`].
+///
+/// `Some(filter)` applies `filter` to the wrapped tokenizer, while `None` passes the
+/// tokenizer through untouched as [`OptionalTokenizer::Disabled`].
+impl<F: TokenFilter> TokenFilter for Option<F> {
+    type Tokenizer<T: Tokenizer> = OptionalTokenizer<F::Tokenizer<T>, T>;
+
+    #[inline]
+    fn transform<T: Tokenizer>(self, tokenizer: T) -> Self::Tokenizer<T> {
+        match self {
+            Some(filter) => OptionalTokenizer::Enabled(filter.transform(tokenizer)),
+            None => OptionalTokenizer::Disabled(tokenizer),
+        }
+    }
+}
+
+/// A [`Tokenizer`] derived from a [`TokenFilter::transform`] on an
+/// [`Option<F>`] token filter.
+#[derive(Clone)]
+pub enum OptionalTokenizer<E: Tokenizer, D: Tokenizer> {
+    /// The filter was `Some`: holds the tokenizer returned by the filter's `transform`.
+    Enabled(E),
+    /// The filter was `None`: holds the tokenizer exactly as it was passed in.
+    Disabled(D),
+}
+
+impl<E: Tokenizer, D: Tokenizer> Tokenizer for OptionalTokenizer<E, D> {
+    type TokenStream<'a> = OptionalTokenStream<E::TokenStream<'a>, D::TokenStream<'a>>;
+
+    #[inline]
+    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
+        // The two inner tokenizers yield different stream types; wrap each in the
+        // matching `OptionalTokenStream` variant so both arms share one return type.
+        match self {
+            Self::Enabled(tokenizer) => {
+                let token_stream = tokenizer.token_stream(text);
+                OptionalTokenStream::Enabled(token_stream)
+            }
+            Self::Disabled(tokenizer) => {
+                let token_stream = tokenizer.token_stream(text);
+                OptionalTokenStream::Disabled(token_stream)
+            }
+        }
+    }
+}
+
+/// A [`TokenStream`] derived from a [`Tokenizer::token_stream`] on an [`OptionalTokenizer<E, D>`].
+pub enum OptionalTokenStream<E: TokenStream, D: TokenStream> {
+    /// Stream produced by [`OptionalTokenizer::Enabled`].
+    Enabled(E),
+    /// Stream produced by [`OptionalTokenizer::Disabled`].
+    Disabled(D),
+}
+
+// Every method simply forwards to whichever inner stream is held.
+impl<E: TokenStream, D: TokenStream> TokenStream for OptionalTokenStream<E, D> {
+    #[inline]
+    fn advance(&mut self) -> bool {
+        match self {
+            Self::Enabled(t) => t.advance(),
+            Self::Disabled(t) => t.advance(),
+        }
+    }
+
+    #[inline]
+    fn token(&self) -> &Token {
+        match self {
+            Self::Enabled(t) => t.token(),
+            Self::Disabled(t) => t.token(),
+        }
+    }
+
+    #[inline]
+    fn token_mut(&mut self) -> &mut Token {
+        match self {
+            Self::Enabled(t) => t.token_mut(),
+            Self::Disabled(t) => t.token_mut(),
+        }
+    }
+}
+
 #[cfg(test)]
 mod test {
     use super::*;