From e32dba1a9747ee2ed68988d910c518d8e4318229 Mon Sep 17 00:00:00 2001
From: Paul Masurel
Date: Mon, 10 Sep 2018 09:26:33 +0900
Subject: [PATCH] Phrase weight

---
 src/query/automaton_weight.rs             |  42 ---------
 src/query/boolean_query/boolean_query.rs  |   8 ++
 src/query/boolean_query/boolean_weight.rs |   9 --
 src/query/mod.rs                          |   2 -
 src/query/phrase_query/phrase_query.rs    |   7 ++
 src/query/phrase_query/phrase_weight.rs   |   5 -
 src/query/query.rs                        |   5 +-
 src/query/range_query.rs                  |   5 -
 src/query/term_query/term_query.rs        |   4 +
 src/query/term_query/term_weight.rs       |  21 -----
 src/query/weight.rs                       |  34 -------
 src/snippet/mod.rs                        | 110 ++++++++--------------
 12 files changed, 61 insertions(+), 191 deletions(-)

diff --git a/src/query/automaton_weight.rs b/src/query/automaton_weight.rs
index 854ecb66e..9ff7b8594 100644
--- a/src/query/automaton_weight.rs
+++ b/src/query/automaton_weight.rs
@@ -7,7 +7,6 @@ use query::{Scorer, Weight};
 use schema::{Field, IndexRecordOption};
 use termdict::{TermDictionary, TermStreamer};
 use Result;
-use query::weight::MatchingTerms;
 use SkipResult;
 use Term;
 use DocId;
@@ -41,47 +40,6 @@ impl<A> Weight for AutomatonWeight<A>
 where
     A: Automaton,
 {
-    fn matching_terms(&self,
-                      reader: &SegmentReader,
-                      matching_terms: &mut MatchingTerms) -> Result<()> {
-        let max_doc = reader.max_doc();
-        let mut doc_bitset = BitSet::with_max_value(max_doc);
-
-        let inverted_index = reader.inverted_index(self.field);
-        let term_dict = inverted_index.terms();
-        let mut term_stream = self.automaton_stream(term_dict);
-
-        let doc_ids = matching_terms.sorted_doc_ids();
-        let mut docs_matching_current_term: Vec<DocId> = vec![];
-
-        let mut term_buffer: Vec<u8> = vec![];
-
-        while term_stream.advance() {
-            docs_matching_current_term.clear();
-            let term_info = term_stream.value();
-            let mut segment_postings =
-                inverted_index.read_postings_from_terminfo(term_info, IndexRecordOption::Basic);
-            for &doc_id in &doc_ids {
-                match segment_postings.skip_next(doc_id) {
-                    SkipResult::Reached => {
-                        docs_matching_current_term.push(doc_id);
-                    }
-                    SkipResult::OverStep => {}
-                    SkipResult::End => {}
-                }
-            }
-            if !docs_matching_current_term.is_empty() {
-                term_buffer.clear();
-                let term_ord = term_stream.term_ord();
-                inverted_index.terms().ord_to_term(term_ord, &mut term_buffer);
-                let term = Term::from_field_bytes(self.field, &term_buffer[..]);
-                for &doc_id in &docs_matching_current_term {
-                    matching_terms.add_term(doc_id, term.clone(), 1f32);
-                }
-            }
-        }
-        Ok(())
-    }
-
     fn scorer(&self, reader: &SegmentReader) -> Result<Box<Scorer>> {
         let max_doc = reader.max_doc();
         let mut doc_bitset = BitSet::with_max_value(max_doc);
diff --git a/src/query/boolean_query/boolean_query.rs b/src/query/boolean_query/boolean_query.rs
index 286d9f449..b92a203eb 100644
--- a/src/query/boolean_query/boolean_query.rs
+++ b/src/query/boolean_query/boolean_query.rs
@@ -6,6 +6,7 @@ use query::Weight;
 use schema::IndexRecordOption;
 use schema::Term;
 use Result;
+use std::collections::BTreeSet;
 use Searcher;
 
 /// The boolean query combines a set of queries
@@ -40,6 +41,7 @@ impl From<Vec<(Occur, Box<Query>)>> for BooleanQuery {
 }
 
 impl Query for BooleanQuery {
+
     fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> Result<Box<Weight>> {
         let sub_weights = self.subqueries
             .iter()
@@ -49,6 +50,12 @@
             .collect::<Result<_>>()?;
         Ok(Box::new(BooleanWeight::new(sub_weights, scoring_enabled)))
     }
+
+    fn query_terms(&self, term_set: &mut BTreeSet<Term>) {
+        for (_occur, subquery) in &self.subqueries {
+            subquery.query_terms(term_set);
+        }
+    }
 }
 
 impl BooleanQuery {
diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs
index 2b3348a21..575bc2991 100644
--- a/src/query/boolean_query/boolean_weight.rs
+++ b/src/query/boolean_query/boolean_weight.rs
@@ -13,7 +13,6 @@ use query::Weight;
 use std::borrow::Borrow;
 use std::collections::HashMap;
 use Result;
-use query::MatchingTerms;
 
 fn scorer_union(scorers: Vec<Box<Scorer>>) -> Box<Scorer>
 where
@@ -108,14 +107,6 @@ impl BooleanWeight {
 }
 
 impl Weight for BooleanWeight {
-
-    fn matching_terms(&self, reader: &SegmentReader, matching_terms: &mut MatchingTerms) -> Result<()> {
-        for (_, weight) in &self.weights {
-            weight.matching_terms(reader, matching_terms)?;
-        }
-        Ok(())
-    }
-
     fn scorer(&self, reader: &SegmentReader) -> Result<Box<Scorer>> {
         if self.weights.is_empty() {
             Ok(Box::new(EmptyScorer))
diff --git a/src/query/mod.rs b/src/query/mod.rs
index 0b6ee2adb..73a77174b 100644
--- a/src/query/mod.rs
+++ b/src/query/mod.rs
@@ -27,8 +27,6 @@ mod weight;
 mod vec_docset;
 pub(crate) mod score_combiner;
 
-pub use self::weight::MatchingTerms;
-
 pub use self::intersection::Intersection;
 pub use self::union::Union;
diff --git a/src/query/phrase_query/phrase_query.rs b/src/query/phrase_query/phrase_query.rs
index e501711ed..d103461c1 100644
--- a/src/query/phrase_query/phrase_query.rs
+++ b/src/query/phrase_query/phrase_query.rs
@@ -6,6 +6,7 @@ use query::Query;
 use query::Weight;
 use schema::{Field, Term};
 use Result;
+use std::collections::BTreeSet;
 
 /// `PhraseQuery` matches a specific sequence of words.
 ///
@@ -107,4 +108,10 @@ impl Query for PhraseQuery {
             )))
         }
     }
+
+    fn query_terms(&self, term_set: &mut BTreeSet<Term>) {
+        for (_, query_term) in &self.phrase_terms {
+            term_set.insert(query_term.clone());
+        }
+    }
 }
diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs
index fbf43db20..69ab4e184 100644
--- a/src/query/phrase_query/phrase_weight.rs
+++ b/src/query/phrase_query/phrase_weight.rs
@@ -7,7 +7,6 @@ use query::Weight;
 use schema::IndexRecordOption;
 use schema::Term;
 use Result;
-use query::MatchingTerms;
 
 pub struct PhraseWeight {
     phrase_terms: Vec<(usize, Term)>,
@@ -32,10 +31,6 @@ impl PhraseWeight {
 
 impl Weight for PhraseWeight {
-
-    fn matching_terms(&self, reader: &SegmentReader, matching_terms: &mut MatchingTerms) -> Result<()> {
-        unimplemented!();
-    }
 
     fn scorer(&self, reader: &SegmentReader) -> Result<Box<Scorer>> {
         let similarity_weight = self.similarity_weight.clone();
         let field = self.phrase_terms[0].1.field();
diff --git a/src/query/query.rs b/src/query/query.rs
index 7004768e4..a72c33d00 100644
--- a/src/query/query.rs
+++ b/src/query/query.rs
@@ -6,7 +6,8 @@ use std::fmt;
 use Result;
 use SegmentLocalId;
 use DocAddress;
-use query::weight::MatchingTerms;
+use std::collections::BTreeSet;
+use Term;
 
 /// The `Query` trait defines a set of documents and a scoring method
 /// for those documents.
@@ -60,6 +61,8 @@ pub trait Query: QueryClone + downcast::Any + fmt::Debug {
         Ok(result)
     }
 
+    fn query_terms(&self, term_set: &mut BTreeSet<Term>) {}
+
     /// Search works as follows :
     ///
     /// First the weight object associated to the query is created.
diff --git a/src/query/range_query.rs b/src/query/range_query.rs
index 2b22e7cf8..06d98db66 100644
--- a/src/query/range_query.rs
+++ b/src/query/range_query.rs
@@ -11,7 +11,6 @@ use std::collections::Bound;
 use std::ops::Range;
 use termdict::{TermDictionary, TermStreamer};
 use Result;
-use query::MatchingTerms;
 
 fn map_bound<TFrom, TTo, Transform: Fn(&TFrom) -> TTo>(
     bound: &Bound<TFrom>,
@@ -276,10 +275,6 @@ impl RangeWeight {
 
 impl Weight for RangeWeight {
-
-    fn matching_terms(&self, reader: &SegmentReader, matching_terms: &mut MatchingTerms) -> Result<()> {
-        unimplemented!();
-    }
 
     fn scorer(&self, reader: &SegmentReader) -> Result<Box<Scorer>> {
         let max_doc = reader.max_doc();
         let mut doc_bitset = BitSet::with_max_value(max_doc);
diff --git a/src/query/term_query/term_query.rs b/src/query/term_query/term_query.rs
index 9ba10b307..d6cd72288 100644
--- a/src/query/term_query/term_query.rs
+++ b/src/query/term_query/term_query.rs
@@ -6,6 +6,7 @@ use schema::IndexRecordOption;
 use Result;
 use Searcher;
 use Term;
+use std::collections::BTreeSet;
 
 /// A Term query matches all of the documents
 /// containing a specific term.
@@ -110,4 +111,7 @@ impl Query for TermQuery {
     fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> Result<Box<Weight>> {
         Ok(Box::new(self.specialized_weight(searcher, scoring_enabled)))
     }
+    fn query_terms(&self, term_set: &mut BTreeSet<Term>) {
+        term_set.insert(self.term.clone());
+    }
 }
diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs
index aa1b5e456..162abe519 100644
--- a/src/query/term_query/term_weight.rs
+++ b/src/query/term_query/term_weight.rs
@@ -9,7 +9,6 @@ use schema::IndexRecordOption;
 use Result;
 use Term;
 use SkipResult;
-use query::weight::MatchingTerms;
 
 pub struct TermWeight {
     term: Term,
@@ -40,26 +39,6 @@ impl Weight for TermWeight {
         }
     }
 
-
-    fn matching_terms(&self,
-                      reader: &SegmentReader,
-                      matching_terms: &mut MatchingTerms) -> Result<()> {
-        let doc_ids = matching_terms.sorted_doc_ids();
-        let mut scorer = self.scorer(reader)?;
-        for doc_id in doc_ids {
-            match scorer.skip_next(doc_id) {
-                SkipResult::Reached => {
-                    matching_terms.add_term(doc_id, self.term.clone(), 1f32);
-                }
-                SkipResult::OverStep => {}
-                SkipResult::End => {
-                    break;
-                }
-            }
-        }
-        Ok(())
-    }
-
     fn count(&self, reader: &SegmentReader) -> Result<u32> {
         if reader.num_deleted_docs() == 0 {
             let field = self.term.field();
diff --git a/src/query/weight.rs b/src/query/weight.rs
index 8a12c01da..8bca9ad16 100644
--- a/src/query/weight.rs
+++ b/src/query/weight.rs
@@ -7,36 +7,6 @@ use Term;
 use std::collections::BTreeMap;
 use std::collections::HashMap;
 
-pub struct MatchingTerms {
-    doc_to_terms: BTreeMap<DocId, HashMap<Term, f32>>
-}
-
-impl MatchingTerms {
-    pub fn from_doc_ids(doc_ids: &[DocId]) -> MatchingTerms {
-        MatchingTerms {
-            doc_to_terms: doc_ids
-                .iter()
-                .cloned()
-                .map(|doc_id| (doc_id, HashMap::default()))
-                .collect()
-        }
-    }
-
-    pub fn terms_for_doc(&self, doc_id: DocId) -> Option<&HashMap<Term, f32>> {
-        self.doc_to_terms.get(&doc_id)
-    }
-
-    pub fn sorted_doc_ids(&self) -> Vec<DocId> {
-        self.doc_to_terms.keys().cloned().collect()
-    }
-
-    pub fn add_term(&mut self, doc_id: DocId, term: Term, score: f32) {
-        if let Some(terms) = self.doc_to_terms.get_mut(&doc_id) {
-            terms.insert(term, score);
-        }
-    }
-}
-
 /// A Weight is the specialization of a Query
 /// for a given set of segments.
 ///
@@ -46,10 +16,6 @@ pub trait Weight {
     /// See [`Query`](./trait.Query.html).
     fn scorer(&self, reader: &SegmentReader) -> Result<Box<Scorer>>;
 
-    fn matching_terms(&self, reader: &SegmentReader, matching_terms: &mut MatchingTerms) -> Result<()> {
-        Ok(())
-    }
-
     /// Returns the number documents within the given `SegmentReader`.
     fn count(&self, reader: &SegmentReader) -> Result<u32> {
         Ok(self.scorer(reader)?.count())
diff --git a/src/snippet/mod.rs b/src/snippet/mod.rs
index ffd6613e3..39d1ff89c 100644
--- a/src/snippet/mod.rs
+++ b/src/snippet/mod.rs
@@ -11,11 +11,11 @@ use query::Query;
 use DocAddress;
 use DocId;
 use Searcher;
-use query::MatchingTerms;
 use schema::Field;
 use std::collections::HashMap;
 use SegmentLocalId;
 use error::TantivyError;
+use std::collections::BTreeSet;
 
 #[derive(Debug)]
 pub struct HighlightSection {
@@ -129,9 +129,9 @@ impl Snippet {
 /// Fragments must be valid in the sense that `&text[fragment.start..fragment.stop]`\
 /// has to be a valid string.
 fn search_fragments<'a>(
-    tokenizer: Box<BoxedTokenizer>,
+    tokenizer: &BoxedTokenizer,
     text: &'a str,
-    terms: BTreeMap<String, f32>,
+    terms: &BTreeMap<String, f32>,
     max_num_chars: usize,
 ) -> Vec<FragmentCandidate> {
     let mut token_stream = tokenizer.token_stream(text);
@@ -199,75 +199,41 @@ fn select_best_fragment_combination<'a>(
 }
 
+const DEFAULT_MAX_NUM_CHARS: usize = 150;
 
-
-fn compute_matching_terms(query: &Query, searcher: &Searcher, doc_addresses: &[DocAddress]) -> Result<HashMap<SegmentLocalId, MatchingTerms>> {
-    let weight = query.weight(searcher, false)?;
-    let mut doc_groups = doc_addresses
-        .iter()
-        .group_by(|doc_address| doc_address.0);
-    let mut matching_terms_per_segment: HashMap<SegmentLocalId, MatchingTerms> = HashMap::new();
-    for (segment_ord, doc_addrs) in doc_groups.into_iter() {
-        let doc_addrs_vec: Vec<DocId> = doc_addrs.map(|doc_addr| doc_addr.1).collect();
-        let mut matching_terms = MatchingTerms::from_doc_ids(&doc_addrs_vec[..]);
-        let segment_reader = searcher.segment_reader(segment_ord);
-        weight.matching_terms(segment_reader, &mut matching_terms)?;
-        matching_terms_per_segment.insert(segment_ord, matching_terms);
-    }
-    Ok(matching_terms_per_segment)
+pub struct SnippetGenerator {
+    terms_text: BTreeMap<String, f32>,
+    tokenizer: Box<BoxedTokenizer>,
+    max_num_chars: usize
 }
 
-pub fn generate_snippet(
-    searcher: &Searcher,
-    query: &Query,
-    field: Field,
-    doc_addresses: &[DocAddress],
-    max_num_chars: usize) -> Result<Vec<Snippet>> {
-
-    let mut doc_address_ords: Vec<usize> = (0..doc_addresses.len()).collect();
-    doc_address_ords.sort_by_key(|k| doc_addresses[*k]);
-
-    let mut snippets = vec![];
-    let matching_terms_per_segment_local_id = compute_matching_terms(query, searcher, doc_addresses)?;
-
-    for &doc_address_ord in &doc_address_ords {
-        let doc_address = doc_addresses[doc_address_ord];
-        let segment_ord: u32 = doc_address.segment_ord();
-        let doc = searcher.doc(&doc_address)?;
-
-        let mut text = String::new();
-        for value in doc.get_all(field) {
-            text.push_str(value.text());
-        }
-
-
-        if let Some(matching_terms) = matching_terms_per_segment_local_id.get(&segment_ord) {
-            let tokenizer = searcher.index().tokenizer_for_field(field)?;
-            if let Some(terms) = matching_terms.terms_for_doc(doc_address.doc()) {
-                let terms: BTreeMap<String, f32> = terms
-                    .iter()
-                    .map(|(term, score)| (term.text().to_string(), *score))
-                    .collect();
-                let fragment_candidates = search_fragments(tokenizer,
-                                                           &text,
-                                                           terms,
-                                                           max_num_chars);
-                let snippet = select_best_fragment_combination(fragment_candidates, &text);
-                snippets.push(snippet);
-            } else {
-                snippets.push(Snippet::empty());
-            }
-        } else {
-
-        }
+impl SnippetGenerator {
+    pub fn new(searcher: &Searcher,
+               query: &Query,
+               field: Field) -> Result<SnippetGenerator> {
+        let mut terms = BTreeSet::new();
+        query.query_terms(&mut terms);
+        let terms_text: BTreeMap<String, f32> = terms.into_iter()
+            .filter(|term| term.field() == field)
+            .map(|term| (term.text().to_string(), 1f32))
+            .collect();
+        let tokenizer = searcher.index().tokenizer_for_field(field)?;
+        Ok(SnippetGenerator {
+            terms_text,
+            tokenizer,
+            max_num_chars: DEFAULT_MAX_NUM_CHARS
+        })
     }
 
-    // reorder the snippets
-    for i in 0..doc_addresses.len() {
-        snippets.swap(i, doc_address_ords[i]);
-    }
+    pub fn snippet(&self, text: &str) -> Snippet {
+        let fragment_candidates = search_fragments(&*self.tokenizer,
+                                                   &text,
+                                                   &self.terms_text,
+                                                   self.max_num_chars);
+        let snippet = select_best_fragment_combination(fragment_candidates, &text);
+        snippet
 
-    Ok(snippets)
+    }
 }
 
 #[cfg(test)]
@@ -294,7 +260,7 @@ Rust won first place for \"most loved programming language\" in the Stack Overfl
         terms.insert(String::from("rust"), 1.0);
         terms.insert(String::from("language"), 0.9);
 
-        let fragments = search_fragments(boxed_tokenizer, &text, terms, 100);
+        let fragments = search_fragments(&*boxed_tokenizer, &text, &terms, 100);
         assert_eq!(fragments.len(), 7);
         {
             let first = fragments.iter().nth(0).unwrap();
@@ -315,7 +281,7 @@
         let mut terms = BTreeMap::new();
         terms.insert(String::from("c"), 1.0);
 
-        let fragments = search_fragments(boxed_tokenizer, &text, terms, 3);
+        let fragments = search_fragments(&*boxed_tokenizer, &text, &terms, 3);
 
         assert_eq!(fragments.len(), 1);
         {
@@ -339,7 +305,7 @@
         let mut terms = BTreeMap::new();
         terms.insert(String::from("f"), 1.0);
 
-        let fragments = search_fragments(boxed_tokenizer, &text, terms, 3);
+        let fragments = search_fragments(&*boxed_tokenizer, &text, &terms, 3);
 
         assert_eq!(fragments.len(), 2);
         {
@@ -364,7 +330,7 @@
         terms.insert(String::from("f"), 1.0);
         terms.insert(String::from("a"), 0.9);
 
-        let fragments = search_fragments(boxed_tokenizer, &text, terms, 7);
+        let fragments = search_fragments(&*boxed_tokenizer, &text, &terms, 7);
 
         assert_eq!(fragments.len(), 2);
         {
@@ -388,7 +354,7 @@
         let mut terms = BTreeMap::new();
         terms.insert(String::from("z"), 1.0);
 
-        let fragments = search_fragments(boxed_tokenizer, &text, terms, 3);
+        let fragments = search_fragments(&*boxed_tokenizer, &text, &terms, 3);
 
         assert_eq!(fragments.len(), 0);
 
@@ -404,7 +370,7 @@
         let text = "a b c d";
         let terms = BTreeMap::new();
 
-        let fragments = search_fragments(boxed_tokenizer, &text, terms, 3);
+        let fragments = search_fragments(&*boxed_tokenizer, &text, &terms, 3);
         assert_eq!(fragments.len(), 0);
 
         let snippet = select_best_fragment_combination(fragments, &text);
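
Usage note: the new `Query::query_terms` hook replaces the removed
`Weight::matching_terms` machinery. Instead of probing postings segment by
segment to find which terms matched each document, a query now simply reports
the terms it was built from. Below is a minimal sketch of how the trait method
composes, assuming the tantivy API as of this patch; the `title` field and the
`Occur::Should` combination are illustrative only.

    use std::collections::BTreeSet;

    use tantivy::query::{BooleanQuery, Occur, Query, TermQuery};
    use tantivy::schema::{IndexRecordOption, SchemaBuilder, TEXT};
    use tantivy::Term;

    fn main() {
        // A `Field` is enough to build queries; no index is needed here.
        let mut schema_builder = SchemaBuilder::default();
        let title = schema_builder.add_text_field("title", TEXT);
        let _schema = schema_builder.build();

        let rust_query: Box<Query> = Box::new(TermQuery::new(
            Term::from_field_text(title, "rust"),
            IndexRecordOption::Basic,
        ));
        let language_query: Box<Query> = Box::new(TermQuery::new(
            Term::from_field_text(title, "language"),
            IndexRecordOption::Basic,
        ));
        let query = BooleanQuery::from(vec![
            (Occur::Should, rust_query),
            (Occur::Should, language_query),
        ]);

        // The default impl on `Query` is a no-op. `TermQuery` inserts its
        // single term and `BooleanQuery` recurses into its subqueries.
        let mut terms: BTreeSet<Term> = BTreeSet::new();
        query.query_terms(&mut terms);
        assert_eq!(terms.len(), 2);
    }

One design consequence: a query that does not override the method, such as
`RangeQuery`, now simply contributes no terms through the empty default,
rather than panicking the way the removed `unimplemented!()` stubs in
`PhraseWeight` and `RangeWeight` did.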
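A hedged end-to-end sketch of the `SnippetGenerator` introduced in
src/snippet/mod.rs. The indexing boilerplate and the
`tantivy::snippet::SnippetGenerator` import path are assumptions for
illustration; only `SnippetGenerator::new` and `snippet` come from this patch.
`new` gathers the query's terms via `query_terms`, keeps the ones on the
snippet field, and looks up that field's tokenizer; `snippet` runs
`search_fragments` and `select_best_fragment_combination` over caller-supplied
text, capped at `DEFAULT_MAX_NUM_CHARS` (150 characters).

    #[macro_use]
    extern crate tantivy;

    use tantivy::query::QueryParser;
    use tantivy::schema::{SchemaBuilder, STORED, TEXT};
    use tantivy::snippet::SnippetGenerator; // assumed public path for the new type
    use tantivy::Index;

    fn main() -> tantivy::Result<()> {
        let mut schema_builder = SchemaBuilder::default();
        let body = schema_builder.add_text_field("body", TEXT | STORED);
        let index = Index::create_in_ram(schema_builder.build());

        let mut writer = index.writer(50_000_000)?;
        writer.add_document(doc!(body => "Rust is a systems programming language."));
        writer.commit()?;
        index.load_searchers()?;
        let searcher = index.searcher();

        let query = QueryParser::for_index(&index, vec![body])
            .parse_query("rust language")
            .expect("a valid query");

        // Built once per (query, field): no per-document postings lookups.
        let snippet_generator = SnippetGenerator::new(&searcher, &*query, body)?;

        // `snippet` is a pure function of the text handed to it, typically
        // the stored value of `body` for a search hit.
        let snippet = snippet_generator.snippet("Rust is a systems programming language.");
        let _ = snippet;
        Ok(())
    }

Compared to the removed `generate_snippet`, the split is cleaner: term
extraction happens once in `new`, fragment scoring happens per document in
`snippet`, and callers fetch the stored text themselves instead of the
generator reading every matching document from the searcher.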