From 29abea07d460bc203e98fdb9d322e7e24e3ffc20 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sun, 14 Aug 2016 10:26:56 +0900 Subject: [PATCH] Refactoring --- src/postings/mod.rs | 2 - .../daat_multiterm_scorer.rs} | 53 ++++++++++------- src/query/mod.rs | 17 ++++-- src/query/multi_term_accumulator.rs | 4 ++ src/query/multi_term_explainer.rs | 36 ------------ src/query/multi_term_query.rs | 19 +++--- src/query/multi_term_scorer.rs | 11 ---- src/query/scorer.rs | 6 +- src/query/similarity.rs | 8 +++ src/query/similarity_explainer.rs | 48 +++++++++++++++ src/query/tfidf.rs | 58 +++++++++---------- src/schema/schema.rs | 9 --- 12 files changed, 144 insertions(+), 127 deletions(-) rename src/{postings/union_postings.rs => query/daat_multiterm_scorer.rs} (81%) create mode 100644 src/query/multi_term_accumulator.rs delete mode 100644 src/query/multi_term_explainer.rs delete mode 100644 src/query/multi_term_scorer.rs create mode 100644 src/query/similarity.rs create mode 100644 src/query/similarity_explainer.rs diff --git a/src/postings/mod.rs b/src/postings/mod.rs index 7480bd129..c4d55bd0f 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -9,13 +9,11 @@ mod segment_postings; mod intersection; mod offset_postings; mod freq_handler; -mod union_postings; mod docset; mod scored_docset; mod segment_postings_option; pub use self::docset::{SkipResult, DocSet}; -pub use self::union_postings::UnionPostings; pub use self::offset_postings::OffsetPostings; pub use self::recorder::{Recorder, NothingRecorder, TermFrequencyRecorder, TFAndPositionRecorder}; pub use self::serializer::PostingsSerializer; diff --git a/src/postings/union_postings.rs b/src/query/daat_multiterm_scorer.rs similarity index 81% rename from src/postings/union_postings.rs rename to src/query/daat_multiterm_scorer.rs index 59f96ffed..60b132579 100644 --- a/src/postings/union_postings.rs +++ b/src/query/daat_multiterm_scorer.rs @@ -2,10 +2,13 @@ use DocId; use postings::{Postings, DocSet}; use std::cmp::Ordering; use std::collections::BinaryHeap; -use query::MultiTermAccumulator; +use query::MultiTermAccumulator; +use query::Similarity; use fastfield::U32FastFieldReader; use query::Occur; use std::iter; +use super::Scorer; +use Score; #[derive(Eq, PartialEq)] struct HeapItem(DocId, u32); @@ -55,27 +58,27 @@ impl Filter { } } -pub struct UnionPostings { +pub struct DAATMultiTermScorer { fieldnorm_readers: Vec, postings: Vec, term_frequencies: Vec, queue: BinaryHeap, doc: DocId, - scorer: TAccumulator, + similarity: TAccumulator, filter: Filter, } -impl UnionPostings { +impl DAATMultiTermScorer { fn new_non_empty( fieldnorm_readers: Vec, postings: Vec, - scorer: TAccumulator, + similarity: TAccumulator, filter: Filter - ) -> UnionPostings { + ) -> DAATMultiTermScorer { let mut term_frequencies: Vec = iter::repeat(0u32).take(postings.len()).collect(); let heap_items: Vec = postings .iter() @@ -88,18 +91,18 @@ impl UnionPostings, scorer: TAccumulator) -> UnionPostings { + pub fn new(postings_and_fieldnorms: Vec<(Occur, TPostings, U32FastFieldReader)>, similarity: TAccumulator) -> DAATMultiTermScorer { let mut postings = Vec::new(); let mut fieldnorm_readers = Vec::new(); let mut occurs = Vec::new(); @@ -111,12 +114,12 @@ impl UnionPostings &TAccumulator { - &self.scorer + &self.similarity } fn advance_head(&mut self,) { @@ -138,11 +141,17 @@ impl UnionPostings DocSet for UnionPostings { +impl Scorer for DAATMultiTermScorer { + fn score(&self,) -> Score { + self.similarity.score() + } +} + +impl DocSet for DAATMultiTermScorer { fn advance(&mut self,) -> bool { loop { - self.scorer.clear(); + self.similarity.clear(); let mut ord_bitset = 0u64; match self.queue.peek() { Some(&HeapItem(doc, ord)) => { @@ -150,7 +159,7 @@ impl DocSet for UnionPo let ord: usize = ord as usize; let fieldnorm = self.get_field_norm(ord, doc); let tf = self.term_frequencies[ord]; - self.scorer.update(ord, tf, fieldnorm); + self.similarity.update(ord, tf, fieldnorm); ord_bitset |= 1 << ord; } None => { @@ -168,7 +177,7 @@ impl DocSet for UnionPo let peek_ord: usize = peek_ord as usize; let peek_tf = self.term_frequencies[peek_ord]; let peek_fieldnorm = self.get_field_norm(peek_ord, peek_doc); - self.scorer.update(peek_ord, peek_tf, peek_fieldnorm); + self.similarity.update(peek_ord, peek_tf, peek_fieldnorm); ord_bitset |= 1 << peek_ord; } } @@ -193,7 +202,7 @@ mod tests { use super::*; use postings::{DocSet, VecPostings}; - use query::TfIdfScorer; + use query::TfIdf; use query::Scorer; use directory::ReadOnlySource; use directory::SharedVec; @@ -225,21 +234,21 @@ mod tests { let right_fieldnorms = create_u32_fastfieldreader(Field(2), vec!(15,25,35)); let left = VecPostings::from(vec!(1, 2, 3)); let right = VecPostings::from(vec!(1, 3, 8)); - let multi_term_scorer = TfIdfScorer::new(vec!(0f32, 1f32, 2f32), vec!(1f32, 4f32)); - let mut union = UnionPostings::new( + let tfidf = TfIdf::new(vec!(0f32, 1f32, 2f32), vec!(1f32, 4f32)); + let mut union = DAATMultiTermScorer::new( vec!( (Occur::Should, left, left_fieldnorms), (Occur::Should, right, right_fieldnorms), ), - multi_term_scorer + tfidf ); assert_eq!(union.next(), Some(1u32)); - assert!(abs_diff(union.scorer().score(), 2.182179f32) < 0.001); + assert!(abs_diff(union.score(), 2.182179f32) < 0.001); assert_eq!(union.next(), Some(2u32)); - assert!(abs_diff(union.scorer().score(), 0.2236068) < 0.001f32); + assert!(abs_diff(union.score(), 0.2236068) < 0.001f32); assert_eq!(union.next(), Some(3u32)); assert_eq!(union.next(), Some(8u32)); - assert!(abs_diff(union.scorer().score(), 0.8944272f32) < 0.001f32); + assert!(abs_diff(union.score(), 0.8944272f32) < 0.001f32); assert!(!union.advance()); } diff --git a/src/query/mod.rs b/src/query/mod.rs index 09c6d463e..7c482b66e 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -1,21 +1,26 @@ mod query; mod multi_term_query; -mod multi_term_scorer; -mod multi_term_explainer; +mod multi_term_accumulator; +mod similarity_explainer; mod scorer; mod query_parser; mod explanation; mod tfidf; mod occur; +mod daat_multiterm_scorer; +mod similarity; + +pub use self::similarity::Similarity; + +pub use self::daat_multiterm_scorer::DAATMultiTermScorer; pub use self::occur::Occur; pub use self::query::Query; pub use self::multi_term_query::MultiTermQuery; -pub use self::multi_term_scorer::MultiTermScorer; -pub use self::multi_term_explainer::MultiTermExplainer; -pub use self::tfidf::TfIdfScorer; +pub use self::similarity_explainer::SimilarityExplainer; +pub use self::tfidf::TfIdf; pub use self::scorer::Scorer; pub use self::query_parser::QueryParser; pub use self::explanation::Explanation; -pub use self::multi_term_scorer::MultiTermAccumulator; +pub use self::multi_term_accumulator::MultiTermAccumulator; diff --git a/src/query/multi_term_accumulator.rs b/src/query/multi_term_accumulator.rs new file mode 100644 index 000000000..6685b1673 --- /dev/null +++ b/src/query/multi_term_accumulator.rs @@ -0,0 +1,4 @@ +pub trait MultiTermAccumulator { + fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32); + fn clear(&mut self,); +} diff --git a/src/query/multi_term_explainer.rs b/src/query/multi_term_explainer.rs deleted file mode 100644 index 453e1561e..000000000 --- a/src/query/multi_term_explainer.rs +++ /dev/null @@ -1,36 +0,0 @@ -use super::MultiTermAccumulator; -use super::MultiTermScorer; -use super::Explanation; - -pub struct MultiTermExplainer { - scorer: TScorer, - vals: Vec<(usize, u32, u32)>, -} - -impl MultiTermExplainer { - pub fn explain_score(&self,) -> Explanation { - self.scorer.explain(&self.vals) - } -} - -impl From for MultiTermExplainer { - fn from(multi_term_scorer: TScorer) -> MultiTermExplainer { - MultiTermExplainer { - scorer: multi_term_scorer, - vals: Vec::new(), - } - } -} - -impl MultiTermAccumulator for MultiTermExplainer { - fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) { - self.vals.push((term_ord, term_freq, fieldnorm)); - self.scorer.update(term_ord, term_freq, fieldnorm); - } - - fn clear(&mut self,) { - self.vals.clear(); - self.scorer.clear(); - } -} - diff --git a/src/query/multi_term_query.rs b/src/query/multi_term_query.rs index ed2378561..48da22d1e 100644 --- a/src/query/multi_term_query.rs +++ b/src/query/multi_term_query.rs @@ -8,11 +8,10 @@ use core::searcher::Searcher; use collector::Collector; use SegmentLocalId; use core::SegmentReader; -use query::MultiTermExplainer; +use query::SimilarityExplainer; use postings::SegmentPostings; -use postings::UnionPostings; use postings::DocSet; -use query::TfIdfScorer; +use query::TfIdf; use postings::SkipResult; use ScoredDoc; use query::Scorer; @@ -21,6 +20,7 @@ use DocAddress; use query::Explanation; use query::occur::Occur; use postings::SegmentPostingsOption; +use query::DAATMultiTermScorer; #[derive(Eq, PartialEq, Debug)] @@ -30,13 +30,12 @@ pub struct MultiTermQuery { impl MultiTermQuery { - pub fn num_terms(&self,) -> usize { self.occur_terms.len() } - fn scorer(&self, searcher: &Searcher) -> TfIdfScorer { + fn scorer(&self, searcher: &Searcher) -> TfIdf { let num_terms = self.num_terms(); let num_docs = searcher.num_docs() as f32; let idfs: Vec = self.occur_terms @@ -59,7 +58,7 @@ impl MultiTermQuery { .iter() .map(|&(_, ref term)| format!("{:?}", &term)) .collect(); - let mut tfidf_scorer = TfIdfScorer::new(query_coords, idfs); + let mut tfidf_scorer = TfIdf::new(query_coords, idfs); tfidf_scorer.set_term_names(term_names); tfidf_scorer } @@ -68,7 +67,7 @@ impl MultiTermQuery { &'b self, reader: &'b SegmentReader, multi_term_scorer: TScorer, - mut timer: OpenTimer<'a>) -> Result> { + mut timer: OpenTimer<'a>) -> Result> { let mut postings_and_fieldnorms = Vec::with_capacity(self.num_terms()); { let mut decode_timer = timer.open("decode_all"); @@ -88,7 +87,7 @@ impl MultiTermQuery { // TODO putting the SHOULD at the end of the list should push the limit. return Err(Error::InvalidArgument(String::from("Limit of 64 terms was exceeded."))); } - Ok(UnionPostings::new(postings_and_fieldnorms, multi_term_scorer)) + Ok(DAATMultiTermScorer::new(postings_and_fieldnorms, multi_term_scorer)) } } @@ -120,7 +119,7 @@ impl Query for MultiTermQuery { searcher: &Searcher, doc_address: &DocAddress) -> Result { let segment_reader = &searcher.segments()[doc_address.segment_ord() as usize]; - let multi_term_scorer = MultiTermExplainer::from(self.scorer(searcher)); + let multi_term_scorer = SimilarityExplainer::from(self.scorer(searcher)); let mut timer_tree = TimerTree::new(); let mut postings = try!( self.search_segment( @@ -164,7 +163,7 @@ impl Query for MultiTermQuery { { let _collection_timer = segment_search_timer.open("collection"); while postings.advance() { - let scored_doc = ScoredDoc(postings.scorer().score(), postings.doc()); + let scored_doc = ScoredDoc(postings.score(), postings.doc()); collector.collect(scored_doc); } } diff --git a/src/query/multi_term_scorer.rs b/src/query/multi_term_scorer.rs deleted file mode 100644 index 00e5ecb5e..000000000 --- a/src/query/multi_term_scorer.rs +++ /dev/null @@ -1,11 +0,0 @@ -use query::Scorer; -use query::Explanation; - -pub trait MultiTermAccumulator { - fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32); - fn clear(&mut self,); -} - -pub trait MultiTermScorer: Scorer + MultiTermAccumulator { - fn explain(&self, vals: &Vec<(usize, u32, u32)>) -> Explanation; -} diff --git a/src/query/scorer.rs b/src/query/scorer.rs index 78b58cf36..36a91cd01 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -1,3 +1,7 @@ -pub trait Scorer { +use DocSet; + +pub trait Scorer: DocSet { fn score(&self,) -> f32; } + + diff --git a/src/query/similarity.rs b/src/query/similarity.rs new file mode 100644 index 000000000..b911cd435 --- /dev/null +++ b/src/query/similarity.rs @@ -0,0 +1,8 @@ +use Score; +use query::Explanation; +use query::MultiTermAccumulator; + +pub trait Similarity: MultiTermAccumulator { + fn score(&self, ) -> Score; + fn explain(&self, vals: &Vec<(usize, u32, u32)>) -> Explanation; +} \ No newline at end of file diff --git a/src/query/similarity_explainer.rs b/src/query/similarity_explainer.rs new file mode 100644 index 000000000..19588e329 --- /dev/null +++ b/src/query/similarity_explainer.rs @@ -0,0 +1,48 @@ +use Score; +use super::MultiTermAccumulator; +use super::Similarity; +use super::Explanation; + +pub struct SimilarityExplainer { + scorer: TSimilarity, + vals: Vec<(usize, u32, u32)>, +} + +impl SimilarityExplainer { + pub fn explain_score(&self,) -> Explanation { + self.scorer.explain(&self.vals) + } +} + +impl From for SimilarityExplainer { + fn from(multi_term_scorer: TSimilarity) -> SimilarityExplainer { + SimilarityExplainer { + scorer: multi_term_scorer, + vals: Vec::new(), + } + } +} + +impl MultiTermAccumulator for SimilarityExplainer { + fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) { + self.vals.push((term_ord, term_freq, fieldnorm)); + self.scorer.update(term_ord, term_freq, fieldnorm); + } + + fn clear(&mut self,) { + self.vals.clear(); + self.scorer.clear(); + } +} + + +impl Similarity for SimilarityExplainer { + + fn score(&self,) -> Score { + self.scorer.score() + } + + fn explain(&self, vals: &Vec<(usize, u32, u32)>) -> Explanation { + self.scorer.explain(vals) + } +} diff --git a/src/query/tfidf.rs b/src/query/tfidf.rs index 6dd2b9f58..3f978ac78 100644 --- a/src/query/tfidf.rs +++ b/src/query/tfidf.rs @@ -1,10 +1,10 @@ +use Score; use super::MultiTermAccumulator; -use super::Scorer; -use super::MultiTermScorer; use super::Explanation; +use super::Similarity; #[derive(Clone)] -pub struct TfIdfScorer { +pub struct TfIdf { coords: Vec, idf: Vec, score: f32, @@ -12,7 +12,7 @@ pub struct TfIdfScorer { term_names: Option>, //< only here for explain } -impl MultiTermAccumulator for TfIdfScorer { +impl MultiTermAccumulator for TfIdf { #[inline(always)] fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) { @@ -28,9 +28,9 @@ impl MultiTermAccumulator for TfIdfScorer { } } -impl TfIdfScorer { - pub fn new(coords: Vec, idf: Vec) -> TfIdfScorer { - TfIdfScorer { +impl TfIdf { + pub fn new(coords: Vec, idf: Vec) -> TfIdf { + TfIdf { coords: coords, idf: idf, score: 0f32, @@ -61,14 +61,12 @@ impl TfIdfScorer { } } -impl Scorer for TfIdfScorer { +impl Similarity for TfIdf { #[inline(always)] - fn score(&self, ) -> f32 { + fn score(&self, ) -> Score { self.score * self.coord() } -} - -impl MultiTermScorer for TfIdfScorer { + fn explain(&self, vals: &Vec<(usize, u32, u32)>) -> Explanation { let score = self.score(); let mut explanation = Explanation::with_val(score); @@ -89,41 +87,41 @@ impl MultiTermScorer for TfIdfScorer { + #[cfg(test)] mod tests { use super::*; - use query::Scorer; use query::MultiTermAccumulator; - + use query::Similarity; + fn abs_diff(left: f32, right: f32) -> f32 { (right - left).abs() } #[test] - pub fn test_multiterm_scorer() { - let mut tfidf_scorer = TfIdfScorer::new(vec!(0f32, 1f32, 2f32), vec!(1f32, 4f32)); + pub fn test_tfidf() { + let mut tfidf = TfIdf::new(vec!(0f32, 1f32, 2f32), vec!(1f32, 4f32)); { - tfidf_scorer.update(0, 1, 1); - assert!(abs_diff(tfidf_scorer.score(), 1f32) < 0.001f32); - tfidf_scorer.clear(); - + tfidf.update(0, 1, 1); + assert!(abs_diff(tfidf.score(), 1f32) < 0.001f32); + tfidf.clear(); } { - tfidf_scorer.update(1, 1, 1); - assert_eq!(tfidf_scorer.score(), 4f32); - tfidf_scorer.clear(); + tfidf.update(1, 1, 1); + assert_eq!(tfidf.score(), 4f32); + tfidf.clear(); } { - tfidf_scorer.update(0, 2, 1); - assert!(abs_diff(tfidf_scorer.score(), 1.4142135) < 0.001f32); - tfidf_scorer.clear(); + tfidf.update(0, 2, 1); + assert!(abs_diff(tfidf.score(), 1.4142135) < 0.001f32); + tfidf.clear(); } { - tfidf_scorer.update(0, 1, 1); - tfidf_scorer.update(1, 1, 1); - assert_eq!(tfidf_scorer.score(), 10f32); - tfidf_scorer.clear(); + tfidf.update(0, 1, 1); + tfidf.update(1, 1, 1); + assert_eq!(tfidf.score(), 10f32); + tfidf.clear(); } diff --git a/src/schema/schema.rs b/src/schema/schema.rs index dd459dc56..10b9807ef 100644 --- a/src/schema/schema.rs +++ b/src/schema/schema.rs @@ -22,20 +22,11 @@ use super::*; /// ``` /// use tantivy::schema::*; /// -/// fn create_schema() -> Schema { /// let mut schema = Schema::new(); -/// let str_fieldtype = TextOptions::new(); /// let id_field = schema.add_text_field("id", STRING); -/// let url_field = schema.add_text_field("url", STRING); -/// let body_field = schema.add_text_field("body", TEXT); -/// let id_field = schema.add_text_field("id", STRING); -/// let url_field = schema.add_text_field("url", STRING); /// let title_field = schema.add_text_field("title", TEXT); /// let body_field = schema.add_text_field("body", TEXT); -/// schema -/// } /// -/// let schema = create_schema(); /// ``` #[derive(Clone, Debug)] pub struct Schema {