From cfc27c9665f6db1971b4ca4063b058decff3c330 Mon Sep 17 00:00:00 2001 From: Evance Souamoro Date: Thu, 29 Apr 2021 11:49:27 +0000 Subject: [PATCH] add support for more like this query --- src/query/mlt/mlt.rs | 387 +++++++++++++++++++++++++++++++++++++++++ src/query/mlt/mod.rs | 5 + src/query/mlt/query.rs | 284 ++++++++++++++++++++++++++++++ src/query/mod.rs | 2 + 4 files changed, 678 insertions(+) create mode 100644 src/query/mlt/mlt.rs create mode 100644 src/query/mlt/mod.rs create mode 100644 src/query/mlt/query.rs diff --git a/src/query/mlt/mlt.rs b/src/query/mlt/mlt.rs new file mode 100644 index 000000000..9839f61b8 --- /dev/null +++ b/src/query/mlt/mlt.rs @@ -0,0 +1,387 @@ +use std::collections::{BinaryHeap, HashMap}; + +use crate::{ + query::{BooleanQuery, BoostQuery, Occur, Query, TermQuery}, + schema::{Field, FieldType, FieldValue, IndexRecordOption, Term, Value}, + tokenizer::{BoxTokenStream, FacetTokenizer, PreTokenizedStream, Tokenizer}, + DocAddress, Result, Searcher, TantivyError, +}; + +#[derive(Debug, PartialEq)] +struct ScoreTerm { + pub term: Term, + pub score: f32, +} + +impl ScoreTerm { + fn new(term: Term, score: f32) -> Self { + Self { term, score } + } +} + +impl Eq for ScoreTerm {} + +impl PartialOrd for ScoreTerm { + fn partial_cmp(&self, other: &Self) -> Option { + self.score.partial_cmp(&other.score) + } +} + +impl Ord for ScoreTerm { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal) + } +} + +/// A struct used as helper to build [`MoreLikeThisQuery`] +#[derive(Debug, Clone)] +pub struct MoreLikeThis { + /// Ignore words which do not occur in at least this many docs. + pub min_doc_frequency: Option, + /// Ignore words which occur in more than this many docs. + pub max_doc_frequency: Option, + /// Ignore words less frequent than this. + pub min_term_frequency: Option, + /// Don't return a query longer than this. + pub max_query_terms: Option, + /// Ignore words if less than this length. + pub min_word_length: Option, + /// Ignore words if greater than this length. + pub max_word_length: Option, + /// Boost factor to use when boosting the terms + pub boost_factor: Option, + /// Current set of stop words. + pub stop_words: Vec, +} + +impl Default for MoreLikeThis { + fn default() -> Self { + Self { + min_doc_frequency: Some(5), + max_doc_frequency: None, + min_term_frequency: Some(2), + max_query_terms: Some(25), + min_word_length: None, + max_word_length: None, + boost_factor: Some(1.0), + stop_words: vec![], + } + } +} + +impl MoreLikeThis { + /// Creates a [`BooleanQuery`] using a document address to collect + /// the top stored field values. + pub fn query_with_document( + &self, + searcher: &Searcher, + doc_address: DocAddress, + ) -> Result { + let score_terms = self.retrieve_terms_from_doc_address(searcher, doc_address)?; + let query = self.create_query(score_terms); + Ok(query) + } + + /// Creates a [`BooleanQuery`] using a set of field values. + pub fn query_with_document_fields( + &self, + searcher: &Searcher, + doc_fields: &[(Field, Vec)], + ) -> Result { + let score_terms = self.retrieve_terms_from_doc_fields(searcher, doc_fields)?; + let query = self.create_query(score_terms); + Ok(query) + } + + /// Creates a [`BooleanQuery`] from an ascendingly sorted list of ScoreTerm + /// This will map the list of ScoreTerm to a list of [`TermQuery`] and compose a + /// BooleanQuery using that list as sub queries. + fn create_query(&self, score_terms: Vec) -> BooleanQuery { + let best_score = score_terms.first().map_or(1f32, |x| x.score); + let mut queries = Vec::new(); + + for ScoreTerm { term, score } in score_terms { + let mut query: Box = + Box::new(TermQuery::new(term, IndexRecordOption::Basic)); + if let Some(factor) = self.boost_factor { + query = Box::new(BoostQuery::new(query, score * factor / best_score)); + } + queries.push((Occur::Should, query)); + } + BooleanQuery::from(queries) + } + + /// Finds terms for a more-like-this query. + /// doc_address is the address of document from which to find terms. + fn retrieve_terms_from_doc_address( + &self, + searcher: &Searcher, + doc_address: DocAddress, + ) -> Result> { + let doc = searcher.doc(doc_address)?; + let field_to_field_values = doc + .get_sorted_field_values() + .iter() + .map(|(field, values)| { + ( + *field, + values + .iter() + .map(|v| (**v).clone()) + .collect::>(), + ) + }) + .collect::>(); + self.retrieve_terms_from_doc_fields(searcher, &field_to_field_values) + } + + /// Finds terms for a more-like-this query. + /// field_to_field_values is a mapping from field to possible values of taht field. + fn retrieve_terms_from_doc_fields( + &self, + searcher: &Searcher, + field_to_field_values: &[(Field, Vec)], + ) -> Result> { + if field_to_field_values.is_empty() { + return Err(TantivyError::InvalidArgument("Cannot create more like this query on empty field values. The document may not have stored fields".to_string())); + } + + let mut field_to_term_freq_map = HashMap::new(); + for (field, field_values) in field_to_field_values { + self.add_term_frequencies(searcher, *field, field_values, &mut field_to_term_freq_map)?; + } + self.create_score_term(searcher, field_to_term_freq_map) + } + + /// Computes the frequency of values for a field while updating the term frequencies + /// Note: A FieldValue can be made up of multiple terms. + /// We are interested in extracting terms within FieldValue + fn add_term_frequencies( + &self, + searcher: &Searcher, + field: Field, + field_values: &[FieldValue], + term_frequencies: &mut HashMap, + ) -> Result<()> { + let schema = searcher.schema(); + let tokenizer_manager = searcher.index().tokenizers(); + + let field_entry = schema.get_field_entry(field); + if !field_entry.is_indexed() { + return Ok(()); + } + + // extract the raw value, possibly tokenizing & filtering to update the term frequency map + match field_entry.field_type() { + FieldType::HierarchicalFacet(_) => { + let facets: Vec<&str> = field_values + .iter() + .map(|field_value| match *field_value.value() { + Value::Facet(ref facet) => Ok(facet.encoded_str()), + _ => Err(TantivyError::InvalidArgument( + "invalid field value".to_string(), + )), + }) + .collect::>>()?; + for fake_str in facets { + FacetTokenizer.token_stream(fake_str).process(&mut |token| { + if self.is_noise_word(token.text.clone()) { + let term = Term::from_field_text(field, &token.text); + *term_frequencies.entry(term).or_insert(0) += 1; + } + }); + } + } + FieldType::Str(text_options) => { + let mut token_streams: Vec = vec![]; + let mut offsets = vec![]; + let mut total_offset = 0; + + for field_value in field_values { + match field_value.value() { + Value::PreTokStr(tok_str) => { + offsets.push(total_offset); + if let Some(last_token) = tok_str.tokens.last() { + total_offset += last_token.offset_to; + } + token_streams.push(PreTokenizedStream::from(tok_str.clone()).into()); + } + Value::Str(ref text) => { + if let Some(tokenizer) = text_options + .get_indexing_options() + .map(|text_indexing_options| { + text_indexing_options.tokenizer().to_string() + }) + .and_then(|tokenizer_name| tokenizer_manager.get(&tokenizer_name)) + { + offsets.push(total_offset); + total_offset += text.len(); + //let v = text.clone(); + token_streams.push(tokenizer.token_stream(text)); + } + } + _ => (), + } + } + + for mut token_stream in token_streams { + token_stream.process(&mut |token| { + if !self.is_noise_word(token.text.clone()) { + let term = Term::from_field_text(field, &token.text); + *term_frequencies.entry(term).or_insert(0) += 1; + } + }); + } + } + FieldType::U64(_) => { + for field_value in field_values { + let val = field_value + .value() + .u64_value() + .ok_or(TantivyError::InvalidArgument("invalid value".to_string()))?; + if !self.is_noise_word(val.to_string()) { + let term = Term::from_field_u64(field, val); + *term_frequencies.entry(term).or_insert(0) += 1; + } + } + } + FieldType::Date(_) => { + for field_value in field_values { + // TODO: Ask if this is the semantic (timestamp) we want + let val = field_value + .value() + .date_value() + .ok_or(TantivyError::InvalidArgument("invalid value".to_string()))? + .timestamp(); + if !self.is_noise_word(val.to_string()) { + let term = Term::from_field_i64(field, val); + *term_frequencies.entry(term).or_insert(0) += 1; + } + } + } + FieldType::I64(_) => { + for field_value in field_values { + let val = field_value + .value() + .i64_value() + .ok_or(TantivyError::InvalidArgument("invalid value".to_string()))?; + if !self.is_noise_word(val.to_string()) { + let term = Term::from_field_i64(field, val); + *term_frequencies.entry(term).or_insert(0) += 1; + } + } + } + FieldType::F64(_) => { + for field_value in field_values { + let val = field_value + .value() + .f64_value() + .ok_or(TantivyError::InvalidArgument("invalid value".to_string()))?; + if !self.is_noise_word(val.to_string()) { + let term = Term::from_field_f64(field, val); + *term_frequencies.entry(term).or_insert(0) += 1; + } + } + } + _ => {} + } + Ok(()) + } + + /// Determines if the term is likely to be of interest based on "more-like-this" settings + fn is_noise_word(&self, word: String) -> bool { + let word_length = word.len(); + if word_length == 0 { + return true; + } + if self + .min_word_length + .map(|min| word_length < min) + .unwrap_or(false) + { + return true; + } + if self + .max_word_length + .map(|max| word_length > max) + .unwrap_or(false) + { + return true; + } + return self.stop_words.contains(&word); + } + + /// Couputes the score for each term while ignoring not useful terms + fn create_score_term( + &self, + searcher: &Searcher, + per_field_term_frequencies: HashMap, + ) -> Result> { + let mut score_terms = BinaryHeap::new(); + let num_docs = searcher + .segment_readers() + .iter() + .map(|x| x.num_docs() as u64) + .sum::(); + + for (term, term_frequency) in per_field_term_frequencies.into_iter() { + // ignore terms with less than min_term_frequency + if self + .min_term_frequency + .map(|x| term_frequency < x) + .unwrap_or(false) + { + continue; + } + + let doc_freq = searcher.doc_freq(&term)?; + + // ignore terms with less than min_doc_frequency + if self + .min_doc_frequency + .map(|x| doc_freq < x) + .unwrap_or(false) + { + continue; + } + + // ignore terms with more than max_doc_frequency + if self + .max_doc_frequency + .map(|x| doc_freq > x) + .unwrap_or(false) + { + continue; + } + + // ignore terms with zero frequency + if doc_freq == 0 { + continue; + } + + // compute similarity & score + let idf = self.idf(doc_freq, num_docs); + let score = (term_frequency as f32) * idf; + score_terms.push(ScoreTerm::new(term, score)); + } + + // limit ourself to max_query terms. we need to sort so to avoid discarding important terms + let score_terms = if let Some(max_query_terms) = self.max_query_terms { + let max_num_terms = std::cmp::min(max_query_terms, score_terms.len()); + score_terms + .into_sorted_vec() + .into_iter() + .take(max_num_terms) + .collect() + } else { + score_terms.into_sorted_vec() + }; + Ok(score_terms) + } + + /// Computes the similarity + fn idf(&self, doc_freq: u64, doc_count: u64) -> f32 { + let x = ((doc_count - doc_freq) as f32 + 0.5) / (doc_freq as f32 + 0.5); + (1f32 + x).ln() + } +} diff --git a/src/query/mlt/mod.rs b/src/query/mlt/mod.rs new file mode 100644 index 000000000..97f541908 --- /dev/null +++ b/src/query/mlt/mod.rs @@ -0,0 +1,5 @@ +mod mlt; +mod query; + +pub use self::mlt::MoreLikeThis; +pub use self::query::{MoreLikeThisQuery}; diff --git a/src/query/mlt/query.rs b/src/query/mlt/query.rs new file mode 100644 index 000000000..3bc5cc833 --- /dev/null +++ b/src/query/mlt/query.rs @@ -0,0 +1,284 @@ +use super::MoreLikeThis; + +use crate::{ + query::{Query, Weight}, + schema::{Field, FieldValue}, + DocAddress, Result, Searcher, TantivyError, +}; + +/// A query that matches all of the documents similar to a document +/// or a set of field values provided. +/// +/// # Examples +/// +/// ``` +/// use tantivy::DocAddress; +/// use tantivy::query::MoreLikeThisQuery; +/// +/// let query = MoreLikeThisQuery::builder() +/// .with_min_doc_frequency(1) +/// .with_max_doc_frequency(10) +/// .with_min_term_frequency(1) +/// .with_min_word_length(2) +/// .with_max_word_length(5) +/// .with_boost_factor(1.0) +/// .with_stop_words(vec!["for".to_string()]) +/// .with_document(DocAddress::new(2, 1)); +/// +/// ``` +#[derive(Debug, Clone)] +pub struct MoreLikeThisQuery { + mlt: MoreLikeThis, + doc_address: Option, + doc_fields: Option)>>, +} + +impl MoreLikeThisQuery { + /// Creates a new builder. + pub fn builder() -> MoreLikeThisQueryBuilder { + MoreLikeThisQueryBuilder::default() + } +} + +impl Query for MoreLikeThisQuery { + fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> Result> { + if let Some(doc_address) = self.doc_address { + return self + .mlt + .query_with_document(searcher, doc_address)? + .weight(searcher, scoring_enabled); + } + + if let Some(ref doc_fields) = self.doc_fields { + return self + .mlt + .query_with_document_fields(searcher, doc_fields)? + .weight(searcher, scoring_enabled); + } + + Err(TantivyError::InvalidArgument("".to_string())) + } +} + +/// The builder for more-like-this query +#[derive(Debug, Clone)] +pub struct MoreLikeThisQueryBuilder { + mlt: MoreLikeThis, +} + +impl Default for MoreLikeThisQueryBuilder { + fn default() -> Self { + Self { + mlt: MoreLikeThis::default(), + } + } +} + +impl MoreLikeThisQueryBuilder { + /// Sets the minimum document frequency. + /// + /// The resulting query will ignore words which do not occur + /// in at least this many docs. + pub fn with_min_doc_frequency(mut self, value: u64) -> Self { + self.mlt.min_doc_frequency = Some(value); + self + } + + /// Sets the maximum document frequency. + /// + /// The resulting query will ignore words which occur + /// in more than this many docs. + pub fn with_max_doc_frequency(mut self, value: u64) -> Self { + self.mlt.max_doc_frequency = Some(value); + self + } + + /// Sets the minimum term frequency. + /// + /// The resulting query will ignore words less + /// frequent that this number. + pub fn with_min_term_frequency(mut self, value: usize) -> Self { + self.mlt.min_term_frequency = Some(value); + self + } + + /// Sets the maximum query terms. + /// + /// The resulting query will not return a query with more clause than this. + pub fn with_max_query_terms(mut self, value: usize) -> Self { + self.mlt.max_query_terms = Some(value); + self + } + + /// Sets the minimum word length. + /// + /// The resulting query will ignore words shorter than this length. + pub fn with_min_word_length(mut self, value: usize) -> Self { + self.mlt.min_word_length = Some(value); + self + } + + /// Sets the maximum word length. + /// + /// The resulting query will ignore words longer than this length. + pub fn with_max_word_length(mut self, value: usize) -> Self { + self.mlt.max_word_length = Some(value); + self + } + + /// Sets the boost factor + /// + /// The boost factor used by the resulting query for boosting terms. + pub fn with_boost_factor(mut self, value: f32) -> Self { + self.mlt.boost_factor = Some(value); + self + } + + /// Sets the set of stop words + /// + /// The resulting query will ignore these set of words. + pub fn with_stop_words(mut self, value: Vec) -> Self { + self.mlt.stop_words = value; + self + } + + /// Sets the document address + /// Returns the constructed [`MoreLikeThisQuery`] + /// + /// This document will be used to collect field values, extract frequent terms + /// needed for composing the query. + /// + /// Note that field values will only be collected from stored fields in the index. + /// You can construct your own field values from any source. + pub fn with_document(self, doc_address: DocAddress) -> MoreLikeThisQuery { + MoreLikeThisQuery { + mlt: self.mlt, + doc_address: Some(doc_address), + doc_fields: None, + } + } + + /// Sets the document fields + /// Returns the constructed [`MoreLikeThisQuery`] + /// + /// This represents the list field values possibly collected from multiple documents + /// that will be used to compose the resulting query. + /// This interface is meant to be used when you want to provide your own set of fields + /// not necessarily from a specific document. + pub fn with_document_fields( + self, + doc_fields: Vec<(Field, Vec)>, + ) -> MoreLikeThisQuery { + MoreLikeThisQuery { + mlt: self.mlt, + doc_address: None, + doc_fields: Some(doc_fields), + } + } +} + +#[cfg(test)] +mod tests { + use super::MoreLikeThisQuery; + use crate::collector::TopDocs; + use crate::schema::{Schema, STORED, TEXT}; + use crate::DocAddress; + use crate::Index; + + fn create_test_index() -> Index { + let mut schema_builder = Schema::builder(); + let title = schema_builder.add_text_field("title", TEXT); + let body = schema_builder.add_text_field("body", TEXT | STORED); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests().unwrap(); + index_writer.add_document(doc!(title => "aaa", body => "the old man and the sea")); + index_writer.add_document(doc!(title => "bbb", body => "an old man sailing on the sea")); + index_writer.add_document(doc!(title => "ccc", body=> "send this message to alice")); + index_writer.add_document(doc!(title => "ddd", body=> "a lady was riding and old bike")); + index_writer.add_document(doc!(title => "eee", body=> "Yes, my lady.")); + index_writer.commit().unwrap(); + index + } + + #[test] + fn test_more_like_this_query_builder() { + // default settings + let query = MoreLikeThisQuery::builder().with_document_fields(vec![]); + + assert_eq!(query.mlt.min_doc_frequency, Some(5)); + assert_eq!(query.mlt.max_doc_frequency, None); + assert_eq!(query.mlt.min_term_frequency, Some(2)); + assert_eq!(query.mlt.max_query_terms, Some(25)); + assert_eq!(query.mlt.min_word_length, None); + assert_eq!(query.mlt.max_word_length, None); + assert_eq!(query.mlt.boost_factor, Some(1.0)); + assert_eq!(query.mlt.stop_words, Vec::::new()); + assert_eq!(query.doc_fields, Some(vec![])); + assert_eq!(query.doc_address, None); + + // custom settings + let query = MoreLikeThisQuery::builder() + .with_min_doc_frequency(2) + .with_max_doc_frequency(5) + .with_min_term_frequency(2) + .with_min_word_length(2) + .with_max_word_length(4) + .with_boost_factor(0.5) + .with_stop_words(vec!["all".to_string(), "for".to_string()]) + .with_document(DocAddress::new(1, 2)); + + assert_eq!(query.mlt.min_doc_frequency, Some(2)); + assert_eq!(query.mlt.max_doc_frequency, Some(5)); + assert_eq!(query.mlt.min_term_frequency, Some(2)); + assert_eq!(query.mlt.min_word_length, Some(2)); + assert_eq!(query.mlt.max_word_length, Some(4)); + assert_eq!(query.mlt.boost_factor, Some(0.5)); + assert_eq!( + query.mlt.stop_words, + vec!["all".to_string(), "for".to_string()] + ); + assert_eq!(query.doc_fields, None); + assert_eq!(query.doc_address, Some(DocAddress::new(1, 2))); + } + + #[test] + fn test_more_like_this_query() { + let index = create_test_index(); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + // search base 1st doc with words [sea, and] skipping [old] + let query = MoreLikeThisQuery::builder() + .with_min_doc_frequency(1) + .with_max_doc_frequency(10) + .with_min_term_frequency(1) + .with_min_word_length(2) + .with_max_word_length(5) + .with_boost_factor(1.0) + .with_stop_words(vec!["old".to_string()]) + .with_document(DocAddress::new(0, 0)); + let top_docs = searcher.search(&query, &TopDocs::with_limit(5)).unwrap(); + let mut doc_ids: Vec<_> = top_docs.iter().map(|item| item.1.doc_id).collect(); + doc_ids.sort(); + + assert_eq!(doc_ids.len(), 3); + assert_eq!(doc_ids, vec![0, 1, 3]); + + // search base 5th doc with words [lady] + let query = MoreLikeThisQuery::builder() + .with_min_doc_frequency(1) + .with_max_doc_frequency(10) + .with_min_term_frequency(1) + .with_min_word_length(2) + .with_max_word_length(5) + .with_boost_factor(1.0) + .with_document(DocAddress::new(0, 4)); + let top_docs = searcher.search(&query, &TopDocs::with_limit(5)).unwrap(); + let mut doc_ids: Vec<_> = top_docs.iter().map(|item| item.1.doc_id).collect(); + doc_ids.sort(); + + assert_eq!(doc_ids.len(), 2); + assert_eq!(doc_ids, vec![3, 4]); + } +} diff --git a/src/query/mod.rs b/src/query/mod.rs index bc8e517bf..9d02540e8 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -11,6 +11,7 @@ mod exclude; mod explanation; mod fuzzy_query; mod intersection; +mod mlt; mod phrase_query; mod query; mod query_parser; @@ -45,6 +46,7 @@ pub use self::explanation::Explanation; pub(crate) use self::fuzzy_query::DfaWrapper; pub use self::fuzzy_query::FuzzyTermQuery; pub use self::intersection::intersect_scorers; +pub use self::mlt::MoreLikeThisQuery; pub use self::phrase_query::PhraseQuery; pub use self::query::{Query, QueryClone}; pub use self::query_parser::QueryParser;