From d4bbec66311c5d06979ba5f9d1acbb7ebc8e5b99 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sat, 6 Aug 2016 18:18:45 +0900 Subject: [PATCH] Safer interface for union_postings --- src/core/index.rs | 15 +-- src/error.rs | 6 +- src/postings/segment_postings.rs | 2 + src/postings/union_postings.rs | 36 ++++--- src/query/mod.rs | 12 ++- src/query/multi_term_explainer.rs | 36 +++++++ src/query/multi_term_query.rs | 40 +++++--- src/query/multi_term_scorer.rs | 155 ------------------------------ src/query/query.rs | 5 +- src/query/tfidf.rs | 132 +++++++++++++++++++++++++ 10 files changed, 231 insertions(+), 208 deletions(-) create mode 100644 src/query/multi_term_explainer.rs create mode 100644 src/query/tfidf.rs diff --git a/src/core/index.rs b/src/core/index.rs index 82b0b609e..a149c8837 100644 --- a/src/core/index.rs +++ b/src/core/index.rs @@ -1,5 +1,4 @@ use Result; -use Error; use std::path::{PathBuf, Path}; use schema::Schema; use DocId; @@ -250,16 +249,10 @@ impl Segment { pub fn open_read(&self, component: SegmentComponent) -> Result { let path = self.relative_path(component); - let directory_lock = self.index.directory.read(); - match directory_lock { - Ok(directory) => { - directory.open_read(&path) - .map_err(From::from) - } - Err(_) => { - Err(Error::Poisoned) - } - } + let directory = try!(self.index.directory.read()); + let source = try!(directory.open_read(&path)); + Ok(source) + } pub fn open_write(&self, component: SegmentComponent) -> Result { diff --git a/src/error.rs b/src/error.rs index af2b6e110..afeb5abdc 100644 --- a/src/error.rs +++ b/src/error.rs @@ -5,8 +5,6 @@ use std::error; use std::sync::PoisonError; use directory::OpenError; - - #[derive(Debug)] pub enum Error { OpenError(OpenError), @@ -20,8 +18,6 @@ impl Error { pub fn make_other(e: E) -> Error { Error::Other(Box::new(e)) } - - } impl From for Error { @@ -31,7 +27,7 @@ impl From for Error { } impl From> for Error { - fn from(poison_error: PoisonError) -> Error { + fn from(_: PoisonError) -> Error { Error::Poisoned } } diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs index 8334a184a..eeed4e6bc 100644 --- a/src/postings/segment_postings.rs +++ b/src/postings/segment_postings.rs @@ -66,6 +66,7 @@ impl<'a> DocSet for SegmentPostings<'a> { // goes to the next element. // next needs to be called a first time to point to the correct element. + #[inline(always)] fn advance(&mut self,) -> bool { self.cur += Wrapping(1); if self.cur.0 >= self.len { @@ -77,6 +78,7 @@ impl<'a> DocSet for SegmentPostings<'a> { return true; } + #[inline(always)] fn doc(&self,) -> DocId { self.block_decoder.output(self.index_within_block()) } diff --git a/src/postings/union_postings.rs b/src/postings/union_postings.rs index 27b9e49bb..ce683245f 100644 --- a/src/postings/union_postings.rs +++ b/src/postings/union_postings.rs @@ -22,7 +22,7 @@ impl Ord for HeapItem { } pub struct UnionPostings { - fieldnorms_readers: Vec, + fieldnorm_readers: Vec, postings: Vec, term_frequencies: Vec, queue: BinaryHeap, @@ -31,14 +31,9 @@ pub struct UnionPostings UnionPostings { - - pub fn new(fieldnorms_reader: Vec, mut postings: Vec, scorer: TAccumulator) -> UnionPostings { - let num_postings = postings.len(); - assert_eq!(fieldnorms_reader.len(), num_postings); - for posting in &mut postings { - assert!(posting.advance()); - } - let mut term_frequencies: Vec = iter::repeat(0u32).take(num_postings).collect(); + + fn new_non_empty(fieldnorm_readers: Vec, postings: Vec, scorer: TAccumulator) -> UnionPostings { + let mut term_frequencies: Vec = iter::repeat(0u32).take(postings.len()).collect(); let heap_items: Vec = postings .iter() .map(|posting| { @@ -50,9 +45,8 @@ impl UnionPostings UnionPostings, scorer: TAccumulator) -> UnionPostings { + let mut postings = Vec::new(); + let mut fieldnorm_readers = Vec::new(); + for (mut posting, fieldnorm_reader) in postings_and_fieldnorms { + if posting.advance() { + postings.push(posting); + fieldnorm_readers.push(fieldnorm_reader); + } + } + UnionPostings::new_non_empty(fieldnorm_readers, postings, scorer) + } pub fn scorer(&self,) -> &TAccumulator { @@ -80,7 +86,7 @@ impl UnionPostings u32 { - self.fieldnorms_readers[ord].get(doc) + self.fieldnorm_readers[ord].get(doc) } } @@ -166,8 +172,10 @@ mod tests { let right = VecPostings::from(vec!(1, 3, 8)); let multi_term_scorer = TfIdfScorer::new(vec!(0f32, 1f32, 2f32), vec!(1f32, 4f32)); let mut union = UnionPostings::new( - vec!(left_fieldnorms, right_fieldnorms), - vec!(left, right), + vec!( + (left, left_fieldnorms), + (right, right_fieldnorms), + ), multi_term_scorer ); assert_eq!(union.next(), Some(1u32)); diff --git a/src/query/mod.rs b/src/query/mod.rs index 075db6b27..102e3387f 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -1,16 +1,22 @@ mod query; mod multi_term_query; mod multi_term_scorer; +mod multi_term_explainer; mod scorer; mod query_parser; mod explanation; +mod tfidf; pub use self::query::Query; + pub use self::multi_term_query::MultiTermQuery; pub use self::multi_term_scorer::MultiTermScorer; -pub use self::multi_term_scorer::TfIdfScorer; -pub use self::multi_term_scorer::MultiTermExplainScorer; +pub use self::multi_term_explainer::MultiTermExplainer; +pub use self::tfidf::TfIdfScorer; + pub use self::scorer::Scorer; pub use self::query_parser::QueryParser; pub use self::explanation::Explanation; -pub use self::multi_term_scorer::MultiTermAccumulator; \ No newline at end of file +pub use self::multi_term_scorer::MultiTermAccumulator; + + diff --git a/src/query/multi_term_explainer.rs b/src/query/multi_term_explainer.rs new file mode 100644 index 000000000..453e1561e --- /dev/null +++ b/src/query/multi_term_explainer.rs @@ -0,0 +1,36 @@ +use super::MultiTermAccumulator; +use super::MultiTermScorer; +use super::Explanation; + +pub struct MultiTermExplainer { + scorer: TScorer, + vals: Vec<(usize, u32, u32)>, +} + +impl MultiTermExplainer { + pub fn explain_score(&self,) -> Explanation { + self.scorer.explain(&self.vals) + } +} + +impl From for MultiTermExplainer { + fn from(multi_term_scorer: TScorer) -> MultiTermExplainer { + MultiTermExplainer { + scorer: multi_term_scorer, + vals: Vec::new(), + } + } +} + +impl MultiTermAccumulator for MultiTermExplainer { + fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) { + self.vals.push((term_ord, term_freq, fieldnorm)); + self.scorer.update(term_ord, term_freq, fieldnorm); + } + + fn clear(&mut self,) { + self.vals.clear(); + self.scorer.clear(); + } +} + diff --git a/src/query/multi_term_query.rs b/src/query/multi_term_query.rs index 250748e36..dfbb80fb5 100644 --- a/src/query/multi_term_query.rs +++ b/src/query/multi_term_query.rs @@ -7,13 +7,12 @@ use core::searcher::Searcher; use collector::Collector; use SegmentLocalId; use core::SegmentReader; -use query::MultiTermExplainScorer; +use query::MultiTermExplainer; use postings::SegmentPostings; use postings::UnionPostings; use postings::DocSet; use query::TfIdfScorer; use postings::SkipResult; -use fastfield::U32FastFieldReader; use ScoredDoc; use query::Scorer; use query::MultiTermAccumulator; @@ -33,12 +32,14 @@ impl Query for MultiTermQuery { searcher: &Searcher, doc_address: &DocAddress) -> Result { let segment_reader = &searcher.segments()[doc_address.segment_ord() as usize]; - let multi_term_scorer = MultiTermExplainScorer::from(self.scorer(searcher)); + let multi_term_scorer = MultiTermExplainer::from(self.scorer(searcher)); let mut timer_tree = TimerTree::new(); - let mut postings = self.search_segment( + let mut postings = try!( + self.search_segment( segment_reader, multi_term_scorer, - timer_tree.open("explain")); + timer_tree.open("explain")) + ); match postings.skip_next(doc_address.doc()) { SkipResult::Reached => { let scorer = postings.scorer(); @@ -67,10 +68,12 @@ impl Query for MultiTermQuery { let _ = segment_search_timer.open("set_segment"); try!(collector.set_segment(segment_ord as SegmentLocalId, &segment_reader)); } - let mut postings = self.search_segment( + let mut postings = try!( + self.search_segment( segment_reader, multi_term_scorer.clone(), - segment_search_timer.open("get_postings")); + segment_search_timer.open("get_postings")) + ); { let _collection_timer = segment_search_timer.open("collection"); while postings.advance() { @@ -123,21 +126,26 @@ impl MultiTermQuery { } } - fn search_segment<'a, 'b, TScorer: MultiTermAccumulator>(&'b self, reader: &'b SegmentReader, multi_term_scorer: TScorer, mut timer: OpenTimer<'a>) -> UnionPostings { - let mut segment_postings: Vec = Vec::with_capacity(self.terms.len()); - let mut fieldnorms_readers: Vec = Vec::with_capacity(self.terms.len()); + fn search_segment<'a, 'b, TScorer: MultiTermAccumulator>( + &'b self, + reader: &'b SegmentReader, + multi_term_scorer: TScorer, + mut timer: OpenTimer<'a>) -> Result> { + let mut postings_and_fieldnorms = Vec::with_capacity(self.num_terms()); { let mut decode_timer = timer.open("decode_all"); for term in &self.terms { let _decode_one_timer = decode_timer.open("decode_one"); - reader.read_postings(term) - .map(|postings| { + match reader.read_postings(term) { + Some(postings) => { let field = term.get_field(); - fieldnorms_readers.push(reader.get_fieldnorms_reader(field).unwrap()); - segment_postings.push(postings); - }); + let fieldnorm_reader = try!(reader.get_fieldnorms_reader(field)); + postings_and_fieldnorms.push((postings, fieldnorm_reader)); + } + None => {} + } } } - UnionPostings::new(fieldnorms_readers, segment_postings, multi_term_scorer) + Ok(UnionPostings::new(postings_and_fieldnorms, multi_term_scorer)) } } diff --git a/src/query/multi_term_scorer.rs b/src/query/multi_term_scorer.rs index ca52e3e3d..00e5ecb5e 100644 --- a/src/query/multi_term_scorer.rs +++ b/src/query/multi_term_scorer.rs @@ -9,158 +9,3 @@ pub trait MultiTermAccumulator { pub trait MultiTermScorer: Scorer + MultiTermAccumulator { fn explain(&self, vals: &Vec<(usize, u32, u32)>) -> Explanation; } - -#[derive(Clone)] -pub struct TfIdfScorer { - coords: Vec, - idf: Vec, - score: f32, - num_fields: usize, - term_names: Option>, //< only here for explain -} - -pub struct MultiTermExplainScorer { - scorer: TScorer, - vals: Vec<(usize, u32, u32)>, -} - -impl MultiTermExplainScorer { - pub fn explain_score(&self,) -> Explanation { - self.scorer.explain(&self.vals) - } -} - -impl From for MultiTermExplainScorer { - fn from(multi_term_scorer: TScorer) -> MultiTermExplainScorer { - MultiTermExplainScorer { - scorer: multi_term_scorer, - vals: Vec::new(), - } - } -} - -impl MultiTermAccumulator for MultiTermExplainScorer { - fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) { - self.vals.push((term_ord, term_freq, fieldnorm)); - self.scorer.update(term_ord, term_freq, fieldnorm); - } - fn clear(&mut self,) { - self.vals.clear(); - self.scorer.clear(); - } -} - -impl TfIdfScorer { - pub fn new(coords: Vec, idf: Vec) -> TfIdfScorer { - TfIdfScorer { - coords: coords, - idf: idf, - score: 0f32, - num_fields: 0, - term_names: None, - } - } - - fn coord(&self,) -> f32 { - self.coords[self.num_fields] - } - - pub fn set_term_names(&mut self, term_names: Vec) { - self.term_names = Some(term_names); - } - - fn term_name(&self, ord: usize) -> String { - match &self.term_names { - &Some(ref term_names_vec) => term_names_vec[ord].clone(), - &None => format!("Field({})", ord) - } - - } - - fn term_score(&self, term_ord: usize, term_freq: u32, field_norm: u32) -> f32 { - (term_freq as f32 / field_norm as f32).sqrt() * self.idf[term_ord] - } -} - -impl Scorer for TfIdfScorer { - fn score(&self, ) -> f32 { - self.score * self.coord() - } -} - -impl MultiTermScorer for TfIdfScorer { - - fn explain(&self, vals: &Vec<(usize, u32, u32)>) -> Explanation { - let score = self.score(); - let mut explanation = Explanation::with_val(score); - let formula_components: Vec = vals.iter() - .map(|&(ord, _, _)| ord) - .map(|ord| format!("", self.term_name(ord))) - .collect(); - let formula = format!(" * ({})", formula_components.join(" + ")); - explanation.set_formula(&formula); - for &(ord, term_freq, field_norm) in vals.iter() { - let term_score = self.term_score(ord, term_freq, field_norm); - let term_explanation = explanation.add_child(&self.term_name(ord), term_score); - term_explanation.set_formula(" sqrt( / ) * "); - } - explanation - } -} - -impl MultiTermAccumulator for TfIdfScorer { - - fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) { - assert!(term_freq != 0u32); - self.score += self.term_score(term_ord, term_freq, fieldnorm); - self.num_fields += 1; - } - - fn clear(&mut self,) { - self.score = 0f32; - self.num_fields = 0; - } -} - - -#[cfg(test)] -mod tests { - - use super::*; - use query::Scorer; - - - fn abs_diff(left: f32, right: f32) -> f32 { - (right - left).abs() - } - - #[test] - pub fn test_multiterm_scorer() { - let mut tfidf_scorer = TfIdfScorer::new(vec!(0f32, 1f32, 2f32), vec!(1f32, 4f32)); - { - tfidf_scorer.update(0, 1, 1); - assert!(abs_diff(tfidf_scorer.score(), 1f32) < 0.001f32); - tfidf_scorer.clear(); - - } - { - tfidf_scorer.update(1, 1, 1); - assert_eq!(tfidf_scorer.score(), 4f32); - tfidf_scorer.clear(); - } - { - tfidf_scorer.update(0, 2, 1); - assert!(abs_diff(tfidf_scorer.score(), 1.4142135) < 0.001f32); - tfidf_scorer.clear(); - } - { - tfidf_scorer.update(0, 1, 1); - tfidf_scorer.update(1, 1, 1); - assert_eq!(tfidf_scorer.score(), 10f32); - tfidf_scorer.clear(); - } - - - } - -} diff --git a/src/query/query.rs b/src/query/query.rs index 841abca25..a3f9a447f 100644 --- a/src/query/query.rs +++ b/src/query/query.rs @@ -15,8 +15,5 @@ pub trait Query { fn explain( &self, searcher: &Searcher, - doc_address: &DocAddress) -> Result { - // TODO check that the document is there or return an error. - panic!("Not implemented"); - } + doc_address: &DocAddress) -> Result; } diff --git a/src/query/tfidf.rs b/src/query/tfidf.rs new file mode 100644 index 000000000..6dd2b9f58 --- /dev/null +++ b/src/query/tfidf.rs @@ -0,0 +1,132 @@ +use super::MultiTermAccumulator; +use super::Scorer; +use super::MultiTermScorer; +use super::Explanation; + +#[derive(Clone)] +pub struct TfIdfScorer { + coords: Vec, + idf: Vec, + score: f32, + num_fields: usize, + term_names: Option>, //< only here for explain +} + +impl MultiTermAccumulator for TfIdfScorer { + + #[inline(always)] + fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) { + assert!(term_freq != 0u32); + self.score += self.term_score(term_ord, term_freq, fieldnorm); + self.num_fields += 1; + } + + #[inline(always)] + fn clear(&mut self,) { + self.score = 0f32; + self.num_fields = 0; + } +} + +impl TfIdfScorer { + pub fn new(coords: Vec, idf: Vec) -> TfIdfScorer { + TfIdfScorer { + coords: coords, + idf: idf, + score: 0f32, + num_fields: 0, + term_names: None, + } + } + + #[inline(always)] + fn coord(&self,) -> f32 { + self.coords[self.num_fields] + } + + pub fn set_term_names(&mut self, term_names: Vec) { + self.term_names = Some(term_names); + } + + fn term_name(&self, ord: usize) -> String { + match &self.term_names { + &Some(ref term_names_vec) => term_names_vec[ord].clone(), + &None => format!("Field({})", ord) + } + } + + #[inline(always)] + fn term_score(&self, term_ord: usize, term_freq: u32, field_norm: u32) -> f32 { + (term_freq as f32 / field_norm as f32).sqrt() * self.idf[term_ord] + } +} + +impl Scorer for TfIdfScorer { + #[inline(always)] + fn score(&self, ) -> f32 { + self.score * self.coord() + } +} + +impl MultiTermScorer for TfIdfScorer { + fn explain(&self, vals: &Vec<(usize, u32, u32)>) -> Explanation { + let score = self.score(); + let mut explanation = Explanation::with_val(score); + let formula_components: Vec = vals.iter() + .map(|&(ord, _, _)| ord) + .map(|ord| format!("", self.term_name(ord))) + .collect(); + let formula = format!(" * ({})", formula_components.join(" + ")); + explanation.set_formula(&formula); + for &(ord, term_freq, field_norm) in vals.iter() { + let term_score = self.term_score(ord, term_freq, field_norm); + let term_explanation = explanation.add_child(&self.term_name(ord), term_score); + term_explanation.set_formula(" sqrt( / ) * "); + } + explanation + } +} + + + +#[cfg(test)] +mod tests { + + use super::*; + use query::Scorer; + use query::MultiTermAccumulator; + + fn abs_diff(left: f32, right: f32) -> f32 { + (right - left).abs() + } + + #[test] + pub fn test_multiterm_scorer() { + let mut tfidf_scorer = TfIdfScorer::new(vec!(0f32, 1f32, 2f32), vec!(1f32, 4f32)); + { + tfidf_scorer.update(0, 1, 1); + assert!(abs_diff(tfidf_scorer.score(), 1f32) < 0.001f32); + tfidf_scorer.clear(); + + } + { + tfidf_scorer.update(1, 1, 1); + assert_eq!(tfidf_scorer.score(), 4f32); + tfidf_scorer.clear(); + } + { + tfidf_scorer.update(0, 2, 1); + assert!(abs_diff(tfidf_scorer.score(), 1.4142135) < 0.001f32); + tfidf_scorer.clear(); + } + { + tfidf_scorer.update(0, 1, 1); + tfidf_scorer.update(1, 1, 1); + assert_eq!(tfidf_scorer.score(), 10f32); + tfidf_scorer.clear(); + } + + + } + +} \ No newline at end of file