add support for more like this query

This commit is contained in:
Evance Souamoro
2021-04-29 11:49:27 +00:00
parent 2b4b16ae90
commit cfc27c9665
4 changed files with 678 additions and 0 deletions

387
src/query/mlt/mlt.rs Normal file
View File

@@ -0,0 +1,387 @@
use std::collections::{BinaryHeap, HashMap};
use crate::{
query::{BooleanQuery, BoostQuery, Occur, Query, TermQuery},
schema::{Field, FieldType, FieldValue, IndexRecordOption, Term, Value},
tokenizer::{BoxTokenStream, FacetTokenizer, PreTokenizedStream, Tokenizer},
DocAddress, Result, Searcher, TantivyError,
};
/// A candidate query term together with the score computed for it.
///
/// Instances are ordered by `score` (see the `Ord` implementation), which lets them be
/// stored in a `BinaryHeap`.
#[derive(Debug, PartialEq)]
struct ScoreTerm {
/// The term that may become a clause of the generated query.
pub term: Term,
/// The score of the term; `create_score_term` computes it as the term frequency
/// multiplied by the value returned by `idf`.
pub score: f32,
}
impl ScoreTerm {
fn new(term: Term, score: f32) -> Self {
Self { term, score }
}
}
impl Eq for ScoreTerm {}

impl PartialOrd for ScoreTerm {
    /// Delegates to [`Ord::cmp`] so that both orderings always agree, as the `Ord`
    /// contract requires.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for ScoreTerm {
    /// Orders two terms by their score only; the term itself is not compared.
    ///
    /// Scores that cannot be compared (a NaN is involved) are treated as equal so that
    /// the ordering stays total.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.score
            .partial_cmp(&other.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}
/// Helper holding the settings used to build a more-like-this query.
///
/// Every optional setting acts as a filter or limit that is only applied when it is
/// `Some`; a `None` value disables the corresponding check.
#[derive(Debug, Clone)]
pub struct MoreLikeThis {
/// Ignore words which do not occur in at least this many docs.
pub min_doc_frequency: Option<u64>,
/// Ignore words which occur in more than this many docs.
pub max_doc_frequency: Option<u64>,
/// Ignore words that occur fewer times than this in the provided field values.
pub min_term_frequency: Option<usize>,
/// Don't return a query with more terms than this.
pub max_query_terms: Option<usize>,
/// Ignore words shorter than this length (compared against `String::len`).
pub min_word_length: Option<usize>,
/// Ignore words longer than this length (compared against `String::len`).
pub max_word_length: Option<usize>,
/// Boost factor to use when boosting the terms.
/// When `None`, the term queries are not wrapped in a boost query.
pub boost_factor: Option<f32>,
/// Current set of stop words; a word found in this list is always ignored.
pub stop_words: Vec<String>,
}
impl Default for MoreLikeThis {
fn default() -> Self {
Self {
min_doc_frequency: Some(5),
max_doc_frequency: None,
min_term_frequency: Some(2),
max_query_terms: Some(25),
min_word_length: None,
max_word_length: None,
boost_factor: Some(1.0),
stop_words: vec![],
}
}
}
impl MoreLikeThis {
/// Builds a [`BooleanQuery`] from the terms of the document stored at `doc_address`.
///
/// The terms are collected from the field values returned by the searcher for that
/// document; any error raised while collecting them is propagated.
pub fn query_with_document(
    &self,
    searcher: &Searcher,
    doc_address: DocAddress,
) -> Result<BooleanQuery> {
    self.retrieve_terms_from_doc_address(searcher, doc_address)
        .map(|score_terms| self.create_query(score_terms))
}
/// Builds a [`BooleanQuery`] from the terms found in the given field values.
///
/// Any error raised while collecting the terms is propagated.
pub fn query_with_document_fields(
    &self,
    searcher: &Searcher,
    doc_fields: &[(Field, Vec<FieldValue>)],
) -> Result<BooleanQuery> {
    self.retrieve_terms_from_doc_fields(searcher, doc_fields)
        .map(|score_terms| self.create_query(score_terms))
}
/// Creates a [`BooleanQuery`] from a list of `ScoreTerm`.
///
/// Every term is mapped to a [`TermQuery`] added as a `Should` clause. When a boost
/// factor is configured, each clause is wrapped in a [`BoostQuery`] whose boost is
/// `score * boost_factor / best_score`, where `best_score` is the highest score of the
/// list. The list may therefore be given in any order.
fn create_query(&self, score_terms: Vec<ScoreTerm>) -> BooleanQuery {
    // Normalise against the highest score. Reading the first element would pick the
    // lowest score when the list is sorted in ascending order.
    let best_score = score_terms
        .iter()
        .max()
        .map_or(1f32, |score_term| score_term.score);
    let mut queries = Vec::with_capacity(score_terms.len());
    for ScoreTerm { term, score } in score_terms {
        let mut query: Box<dyn Query> =
            Box::new(TermQuery::new(term, IndexRecordOption::Basic));
        if let Some(factor) = self.boost_factor {
            query = Box::new(BoostQuery::new(query, score * factor / best_score));
        }
        queries.push((Occur::Should, query));
    }
    BooleanQuery::from(queries)
}
/// Finds terms for a more-like-this query.
/// doc_address is the address of document from which to find terms.
fn retrieve_terms_from_doc_address(
&self,
searcher: &Searcher,
doc_address: DocAddress,
) -> Result<Vec<ScoreTerm>> {
let doc = searcher.doc(doc_address)?;
let field_to_field_values = doc
.get_sorted_field_values()
.iter()
.map(|(field, values)| {
(
*field,
values
.iter()
.map(|v| (**v).clone())
.collect::<Vec<FieldValue>>(),
)
})
.collect::<Vec<_>>();
self.retrieve_terms_from_doc_fields(searcher, &field_to_field_values)
}
/// Finds terms for a more-like-this query.
/// field_to_field_values is a mapping from field to possible values of taht field.
fn retrieve_terms_from_doc_fields(
&self,
searcher: &Searcher,
field_to_field_values: &[(Field, Vec<FieldValue>)],
) -> Result<Vec<ScoreTerm>> {
if field_to_field_values.is_empty() {
return Err(TantivyError::InvalidArgument("Cannot create more like this query on empty field values. The document may not have stored fields".to_string()));
}
let mut field_to_term_freq_map = HashMap::new();
for (field, field_values) in field_to_field_values {
self.add_term_frequencies(searcher, *field, field_values, &mut field_to_term_freq_map)?;
}
self.create_score_term(searcher, field_to_term_freq_map)
}
/// Extracts the terms of `field_values` and adds their occurrences to `term_frequencies`.
///
/// A single [`FieldValue`] can be made up of several terms: text and facet values are
/// tokenized, while each numeric or date value yields one term. Terms rejected by
/// `is_noise_word` are not counted. Fields that are not indexed, and field types that
/// are not handled below, are skipped without error.
///
/// For text fields, values that are neither plain nor pre-tokenized text are ignored,
/// and plain text is ignored when no tokenizer can be resolved for the field.
///
/// # Errors
///
/// Returns [`TantivyError::InvalidArgument`] when a value of a facet, `u64`, `i64`,
/// `f64` or date field does not hold the expected kind of value.
fn add_term_frequencies(
    &self,
    searcher: &Searcher,
    field: Field,
    field_values: &[FieldValue],
    term_frequencies: &mut HashMap<Term, usize>,
) -> Result<()> {
    let schema = searcher.schema();
    let tokenizer_manager = searcher.index().tokenizers();
    let field_entry = schema.get_field_entry(field);
    if !field_entry.is_indexed() {
        return Ok(());
    }
    // extract the raw value, possibly tokenizing & filtering to update the term frequency map
    match field_entry.field_type() {
        FieldType::HierarchicalFacet(_) => {
            let facets: Vec<&str> = field_values
                .iter()
                .map(|field_value| match *field_value.value() {
                    Value::Facet(ref facet) => Ok(facet.encoded_str()),
                    _ => Err(TantivyError::InvalidArgument(
                        "invalid field value".to_string(),
                    )),
                })
                .collect::<Result<Vec<_>>>()?;
            for fake_str in facets {
                FacetTokenizer.token_stream(fake_str).process(&mut |token| {
                    // Count only the tokens that are NOT noise, like every other branch.
                    if !self.is_noise_word(token.text.clone()) {
                        let term = Term::from_field_text(field, &token.text);
                        *term_frequencies.entry(term).or_insert(0) += 1;
                    }
                });
            }
        }
        FieldType::Str(text_options) => {
            let mut token_streams: Vec<BoxTokenStream> = vec![];
            for field_value in field_values {
                match field_value.value() {
                    Value::PreTokStr(tok_str) => {
                        token_streams.push(PreTokenizedStream::from(tok_str.clone()).into());
                    }
                    Value::Str(ref text) => {
                        if let Some(tokenizer) = text_options
                            .get_indexing_options()
                            .map(|text_indexing_options| {
                                text_indexing_options.tokenizer().to_string()
                            })
                            .and_then(|tokenizer_name| tokenizer_manager.get(&tokenizer_name))
                        {
                            token_streams.push(tokenizer.token_stream(text));
                        }
                    }
                    _ => (),
                }
            }
            for mut token_stream in token_streams {
                token_stream.process(&mut |token| {
                    if !self.is_noise_word(token.text.clone()) {
                        let term = Term::from_field_text(field, &token.text);
                        *term_frequencies.entry(term).or_insert(0) += 1;
                    }
                });
            }
        }
        FieldType::U64(_) => {
            for field_value in field_values {
                let val = field_value
                    .value()
                    .u64_value()
                    .ok_or_else(|| TantivyError::InvalidArgument("invalid value".to_string()))?;
                if !self.is_noise_word(val.to_string()) {
                    let term = Term::from_field_u64(field, val);
                    *term_frequencies.entry(term).or_insert(0) += 1;
                }
            }
        }
        FieldType::Date(_) => {
            for field_value in field_values {
                // TODO: Ask if this is the semantic (timestamp) we want
                let val = field_value
                    .value()
                    .date_value()
                    .ok_or_else(|| TantivyError::InvalidArgument("invalid value".to_string()))?
                    .timestamp();
                if !self.is_noise_word(val.to_string()) {
                    let term = Term::from_field_i64(field, val);
                    *term_frequencies.entry(term).or_insert(0) += 1;
                }
            }
        }
        FieldType::I64(_) => {
            for field_value in field_values {
                let val = field_value
                    .value()
                    .i64_value()
                    .ok_or_else(|| TantivyError::InvalidArgument("invalid value".to_string()))?;
                if !self.is_noise_word(val.to_string()) {
                    let term = Term::from_field_i64(field, val);
                    *term_frequencies.entry(term).or_insert(0) += 1;
                }
            }
        }
        FieldType::F64(_) => {
            for field_value in field_values {
                let val = field_value
                    .value()
                    .f64_value()
                    .ok_or_else(|| TantivyError::InvalidArgument("invalid value".to_string()))?;
                if !self.is_noise_word(val.to_string()) {
                    let term = Term::from_field_f64(field, val);
                    *term_frequencies.entry(term).or_insert(0) += 1;
                }
            }
        }
        _ => {}
    }
    Ok(())
}
/// Tells whether `word` must be ignored according to the "more-like-this" settings.
///
/// A word is noise when it is empty, shorter than `min_word_length`, longer than
/// `max_word_length`, or listed in `stop_words`. Lengths are compared against
/// `String::len`, i.e. the size of the word in bytes.
fn is_noise_word(&self, word: String) -> bool {
    let word_length = word.len();
    let too_short = self
        .min_word_length
        .map_or(false, |min| word_length < min);
    let too_long = self
        .max_word_length
        .map_or(false, |max| word_length > max);
    word_length == 0 || too_short || too_long || self.stop_words.contains(&word)
}
/// Computes the score of each term while ignoring the terms that are not useful.
///
/// A term is dropped when its frequency is below `min_term_frequency`, when its
/// document frequency is zero, below `min_doc_frequency` or above `max_doc_frequency`.
/// The score of a remaining term is its frequency multiplied by its `idf`.
///
/// The returned list is sorted by ascending score. When `max_query_terms` is set, only
/// that many of the highest-scoring terms are kept.
///
/// # Errors
///
/// Propagates any error returned by the searcher while reading a document frequency.
fn create_score_term(
    &self,
    searcher: &Searcher,
    per_field_term_frequencies: HashMap<Term, usize>,
) -> Result<Vec<ScoreTerm>> {
    let mut score_terms = BinaryHeap::new();
    let num_docs = searcher
        .segment_readers()
        .iter()
        .map(|x| x.num_docs() as u64)
        .sum::<u64>();
    for (term, term_frequency) in per_field_term_frequencies {
        // ignore terms with less than min_term_frequency
        if self
            .min_term_frequency
            .map_or(false, |min| term_frequency < min)
        {
            continue;
        }
        let doc_freq = searcher.doc_freq(&term)?;
        // ignore terms with zero frequency
        if doc_freq == 0 {
            continue;
        }
        // ignore terms with less than min_doc_frequency
        if self.min_doc_frequency.map_or(false, |min| doc_freq < min) {
            continue;
        }
        // ignore terms with more than max_doc_frequency
        if self.max_doc_frequency.map_or(false, |max| doc_freq > max) {
            continue;
        }
        // compute similarity & score
        let idf = self.idf(doc_freq, num_docs);
        let score = (term_frequency as f32) * idf;
        score_terms.push(ScoreTerm::new(term, score));
    }
    // `into_sorted_vec` sorts by ascending score, so the most important terms sit at the
    // END of the vector: when limiting ourselves to max_query_terms we must drop the
    // head of the vector, not its tail.
    let mut score_terms = score_terms.into_sorted_vec();
    if let Some(max_query_terms) = self.max_query_terms {
        let first_kept = score_terms.len().saturating_sub(max_query_terms);
        score_terms = score_terms.split_off(first_kept);
    }
    Ok(score_terms)
}
/// Computes the inverse document frequency of a term as
/// `ln(1 + (doc_count - doc_freq + 0.5) / (doc_freq + 0.5))`.
///
/// `doc_freq` is the number of documents containing the term and `doc_count` the total
/// number of documents. The subtraction saturates at zero so that a `doc_freq` larger
/// than `doc_count` cannot underflow.
/// NOTE(review): this can presumably happen if the document frequency also counts
/// deleted documents while `doc_count` does not; confirm against the searcher.
fn idf(&self, doc_freq: u64, doc_count: u64) -> f32 {
    let docs_without_term = doc_count.saturating_sub(doc_freq) as f32;
    let x = (docs_without_term + 0.5) / (doc_freq as f32 + 0.5);
    (1f32 + x).ln()
}
}

5
src/query/mlt/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
mod mlt;
mod query;
pub use self::mlt::MoreLikeThis;
pub use self::query::{MoreLikeThisQuery};

284
src/query/mlt/query.rs Normal file
View File

@@ -0,0 +1,284 @@
use super::MoreLikeThis;
use crate::{
query::{Query, Weight},
schema::{Field, FieldValue},
DocAddress, Result, Searcher, TantivyError,
};
/// A query that matches all of the documents similar to a document
/// or a set of field values provided.
///
/// Instances are created through [`MoreLikeThisQuery::builder`]: the builder is
/// finished by providing either a document address or a set of field values.
///
/// # Examples
///
/// ```
/// use tantivy::DocAddress;
/// use tantivy::query::MoreLikeThisQuery;
///
/// let query = MoreLikeThisQuery::builder()
/// .with_min_doc_frequency(1)
/// .with_max_doc_frequency(10)
/// .with_min_term_frequency(1)
/// .with_min_word_length(2)
/// .with_max_word_length(5)
/// .with_boost_factor(1.0)
/// .with_stop_words(vec!["for".to_string()])
/// .with_document(DocAddress::new(2, 1));
///
/// ```
#[derive(Debug, Clone)]
pub struct MoreLikeThisQuery {
// Settings used to select and score the terms of the query.
mlt: MoreLikeThis,
// Address of the source document; set by the builder's `with_document`.
doc_address: Option<DocAddress>,
// Source field values; set by the builder's `with_document_fields`.
doc_fields: Option<Vec<(Field, Vec<FieldValue>)>>,
}
impl MoreLikeThisQuery {
/// Creates a new builder.
pub fn builder() -> MoreLikeThisQueryBuilder {
MoreLikeThisQueryBuilder::default()
}
}
impl Query for MoreLikeThisQuery {
    /// Builds the underlying boolean query and returns its weight.
    ///
    /// The document address takes precedence over the document fields when both are
    /// set.
    ///
    /// # Errors
    ///
    /// Returns [`TantivyError::InvalidArgument`] when neither a document address nor
    /// document fields are available, and propagates any error raised while building
    /// the boolean query or its weight.
    fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> Result<Box<dyn Weight>> {
        let boolean_query = match (self.doc_address, self.doc_fields.as_ref()) {
            (Some(doc_address), _) => self.mlt.query_with_document(searcher, doc_address)?,
            (None, Some(doc_fields)) => {
                self.mlt.query_with_document_fields(searcher, doc_fields)?
            }
            (None, None) => {
                return Err(TantivyError::InvalidArgument(
                    "MoreLikeThisQuery requires either a document address or a set of document fields"
                        .to_string(),
                ));
            }
        };
        boolean_query.weight(searcher, scoring_enabled)
    }
}
/// The builder for more-like-this query
#[derive(Debug, Clone)]
pub struct MoreLikeThisQueryBuilder {
mlt: MoreLikeThis,
}
impl Default for MoreLikeThisQueryBuilder {
fn default() -> Self {
Self {
mlt: MoreLikeThis::default(),
}
}
}
impl MoreLikeThisQueryBuilder {
/// Sets the minimum document frequency.
///
/// The resulting query will ignore words which do not occur
/// in at least this many docs.
pub fn with_min_doc_frequency(mut self, value: u64) -> Self {
self.mlt.min_doc_frequency = Some(value);
self
}
/// Sets the maximum document frequency.
///
/// The resulting query will ignore words which occur
/// in more than this many docs.
pub fn with_max_doc_frequency(mut self, value: u64) -> Self {
self.mlt.max_doc_frequency = Some(value);
self
}
/// Sets the minimum term frequency.
///
/// The resulting query will ignore words less
/// frequent that this number.
pub fn with_min_term_frequency(mut self, value: usize) -> Self {
self.mlt.min_term_frequency = Some(value);
self
}
/// Sets the maximum query terms.
///
/// The resulting query will not return a query with more clause than this.
pub fn with_max_query_terms(mut self, value: usize) -> Self {
self.mlt.max_query_terms = Some(value);
self
}
/// Sets the minimum word length.
///
/// The resulting query will ignore words shorter than this length.
pub fn with_min_word_length(mut self, value: usize) -> Self {
self.mlt.min_word_length = Some(value);
self
}
/// Sets the maximum word length.
///
/// The resulting query will ignore words longer than this length.
pub fn with_max_word_length(mut self, value: usize) -> Self {
self.mlt.max_word_length = Some(value);
self
}
/// Sets the boost factor
///
/// The boost factor used by the resulting query for boosting terms.
pub fn with_boost_factor(mut self, value: f32) -> Self {
self.mlt.boost_factor = Some(value);
self
}
/// Sets the set of stop words
///
/// The resulting query will ignore these set of words.
pub fn with_stop_words(mut self, value: Vec<String>) -> Self {
self.mlt.stop_words = value;
self
}
/// Sets the document address
/// Returns the constructed [`MoreLikeThisQuery`]
///
/// This document will be used to collect field values, extract frequent terms
/// needed for composing the query.
///
/// Note that field values will only be collected from stored fields in the index.
/// You can construct your own field values from any source.
pub fn with_document(self, doc_address: DocAddress) -> MoreLikeThisQuery {
MoreLikeThisQuery {
mlt: self.mlt,
doc_address: Some(doc_address),
doc_fields: None,
}
}
/// Sets the document fields
/// Returns the constructed [`MoreLikeThisQuery`]
///
/// This represents the list field values possibly collected from multiple documents
/// that will be used to compose the resulting query.
/// This interface is meant to be used when you want to provide your own set of fields
/// not necessarily from a specific document.
pub fn with_document_fields(
self,
doc_fields: Vec<(Field, Vec<FieldValue>)>,
) -> MoreLikeThisQuery {
MoreLikeThisQuery {
mlt: self.mlt,
doc_address: None,
doc_fields: Some(doc_fields),
}
}
}
// Unit tests for the more-like-this query builder and for running the query
// against a small in-memory index.
#[cfg(test)]
mod tests {
use super::MoreLikeThisQuery;
use crate::collector::TopDocs;
use crate::schema::{Schema, STORED, TEXT};
use crate::DocAddress;
use crate::Index;
// Builds an in-RAM index holding five documents with a `title` field (`TEXT`) and a
// `body` field (`TEXT | STORED`), committed in a single batch.
fn create_test_index() -> Index {
let mut schema_builder = Schema::builder();
let title = schema_builder.add_text_field("title", TEXT);
let body = schema_builder.add_text_field("body", TEXT | STORED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer.add_document(doc!(title => "aaa", body => "the old man and the sea"));
index_writer.add_document(doc!(title => "bbb", body => "an old man sailing on the sea"));
index_writer.add_document(doc!(title => "ccc", body=> "send this message to alice"));
index_writer.add_document(doc!(title => "ddd", body=> "a lady was riding and old bike"));
index_writer.add_document(doc!(title => "eee", body=> "Yes, my lady."));
index_writer.commit().unwrap();
index
}
// Checks that the builder starts from the default settings and that every `with_*`
// method stores the value it is given.
#[test]
fn test_more_like_this_query_builder() {
// default settings
let query = MoreLikeThisQuery::builder().with_document_fields(vec![]);
assert_eq!(query.mlt.min_doc_frequency, Some(5));
assert_eq!(query.mlt.max_doc_frequency, None);
assert_eq!(query.mlt.min_term_frequency, Some(2));
assert_eq!(query.mlt.max_query_terms, Some(25));
assert_eq!(query.mlt.min_word_length, None);
assert_eq!(query.mlt.max_word_length, None);
assert_eq!(query.mlt.boost_factor, Some(1.0));
assert_eq!(query.mlt.stop_words, Vec::<String>::new());
assert_eq!(query.doc_fields, Some(vec![]));
assert_eq!(query.doc_address, None);
// custom settings
let query = MoreLikeThisQuery::builder()
.with_min_doc_frequency(2)
.with_max_doc_frequency(5)
.with_min_term_frequency(2)
.with_min_word_length(2)
.with_max_word_length(4)
.with_boost_factor(0.5)
.with_stop_words(vec!["all".to_string(), "for".to_string()])
.with_document(DocAddress::new(1, 2));
assert_eq!(query.mlt.min_doc_frequency, Some(2));
assert_eq!(query.mlt.max_doc_frequency, Some(5));
assert_eq!(query.mlt.min_term_frequency, Some(2));
assert_eq!(query.mlt.min_word_length, Some(2));
assert_eq!(query.mlt.max_word_length, Some(4));
assert_eq!(query.mlt.boost_factor, Some(0.5));
assert_eq!(
query.mlt.stop_words,
vec!["all".to_string(), "for".to_string()]
);
assert_eq!(query.doc_fields, None);
assert_eq!(query.doc_address, Some(DocAddress::new(1, 2)));
}
// Runs two more-like-this queries against the test index and checks which
// documents are returned.
#[test]
fn test_more_like_this_query() {
let index = create_test_index();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
// Query built from the 1st document, with `old` configured as a stop word.
let query = MoreLikeThisQuery::builder()
.with_min_doc_frequency(1)
.with_max_doc_frequency(10)
.with_min_term_frequency(1)
.with_min_word_length(2)
.with_max_word_length(5)
.with_boost_factor(1.0)
.with_stop_words(vec!["old".to_string()])
.with_document(DocAddress::new(0, 0));
let top_docs = searcher.search(&query, &TopDocs::with_limit(5)).unwrap();
let mut doc_ids: Vec<_> = top_docs.iter().map(|item| item.1.doc_id).collect();
doc_ids.sort();
assert_eq!(doc_ids.len(), 3);
assert_eq!(doc_ids, vec![0, 1, 3]);
// Query built from the 5th document; the 4th document is expected to match as well.
let query = MoreLikeThisQuery::builder()
.with_min_doc_frequency(1)
.with_max_doc_frequency(10)
.with_min_term_frequency(1)
.with_min_word_length(2)
.with_max_word_length(5)
.with_boost_factor(1.0)
.with_document(DocAddress::new(0, 4));
let top_docs = searcher.search(&query, &TopDocs::with_limit(5)).unwrap();
let mut doc_ids: Vec<_> = top_docs.iter().map(|item| item.1.doc_id).collect();
doc_ids.sort();
assert_eq!(doc_ids.len(), 2);
assert_eq!(doc_ids, vec![3, 4]);
}
}

View File

@@ -11,6 +11,7 @@ mod exclude;
mod explanation;
mod fuzzy_query;
mod intersection;
mod mlt;
mod phrase_query;
mod query;
mod query_parser;
@@ -45,6 +46,7 @@ pub use self::explanation::Explanation;
pub(crate) use self::fuzzy_query::DfaWrapper;
pub use self::fuzzy_query::FuzzyTermQuery;
pub use self::intersection::intersect_scorers;
pub use self::mlt::MoreLikeThisQuery;
pub use self::phrase_query::PhraseQuery;
pub use self::query::{Query, QueryClone};
pub use self::query_parser::QueryParser;