From 0ca78328759b1c85ca791fe849f7f3c139295012 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sat, 29 Oct 2016 13:00:08 +0900 Subject: [PATCH 01/19] issue/50 First try The dynamic dispatch is not hurting perf a little. The "the" query goes from 158 ms -> 188 ms. --- src/lib.rs | 1 + src/query/mod.rs | 4 +- src/query/multi_term_query.rs | 136 +++++++++++----------------------- src/query/phrase_query.rs | 2 +- src/query/query.rs | 45 ++++++++--- src/query/query_parser.rs | 25 ++++--- src/query/weight.rs | 11 +++ 7 files changed, 105 insertions(+), 119 deletions(-) create mode 100644 src/query/weight.rs diff --git a/src/lib.rs b/src/lib.rs index d30bc6976..9106fe8a5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ #![allow(unknown_lints)] // for the clippy lint options #![allow(module_inception)] +#![feature(box_syntax)] #![feature(optin_builtin_traits)] #![feature(conservative_impl_trait)] #![cfg_attr(test, feature(test))] diff --git a/src/query/mod.rs b/src/query/mod.rs index 12af7d35c..7a516d61f 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -14,6 +14,7 @@ mod tfidf; mod occur; mod daat_multiterm_scorer; mod similarity; +mod weight; pub use self::similarity::Similarity; @@ -29,4 +30,5 @@ pub use self::scorer::Scorer; pub use self::query_parser::QueryParser; pub use self::explanation::Explanation; pub use self::multi_term_accumulator::MultiTermAccumulator; -pub use self::query_parser::ParsingError; \ No newline at end of file +pub use self::query_parser::ParsingError; +pub use self::weight::Weight; \ No newline at end of file diff --git a/src/query/multi_term_query.rs b/src/query/multi_term_query.rs index 24026e6f8..aa0287e76 100644 --- a/src/query/multi_term_query.rs +++ b/src/query/multi_term_query.rs @@ -1,30 +1,49 @@ use Result; +use super::Weight; use Error; use schema::Term; use query::Query; -use common::TimerTree; -use common::OpenTimer; use core::searcher::Searcher; -use collector::Collector; -use SegmentLocalId; use core::SegmentReader; -use query::SimilarityExplainer; -use postings::SegmentPostings; -use postings::DocSet; use query::TfIdf; -use postings::SkipResult; -use ScoredDoc; use query::Scorer; -use query::MultiTermAccumulator; -use DocAddress; -use query::Explanation; use query::occur::Occur; use postings::SegmentPostingsOption; use query::DAATMultiTermScorer; + +struct MultiTermWeight { + query: MultiTermQuery, + similitude: TfIdf, +} + + +impl Weight for MultiTermWeight { + + + fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { + + let mut postings_and_fieldnorms = Vec::with_capacity(self.query.num_terms()); + { + for &(occur, ref term) in &self.query.occur_terms { + if let Some(postings) = reader.read_postings(term, SegmentPostingsOption::Freq) { + let field = term.field(); + let fieldnorm_reader = try!(reader.get_fieldnorms_reader(field)); + postings_and_fieldnorms.push((occur, postings, fieldnorm_reader)); + } + } + } + if postings_and_fieldnorms.len() > 64 { + // TODO putting the SHOULD at the end of the list should push the limit. + return Err(Error::InvalidArgument(String::from("Limit of 64 terms was exceeded."))); + } + Ok(box DAATMultiTermScorer::new(postings_and_fieldnorms, self.similitude.clone())) + } +} + /// Query involving one or more terms. -#[derive(Eq, PartialEq, Debug)] +#[derive(Eq, Clone, PartialEq, Debug)] pub struct MultiTermQuery { occur_terms: Vec<(Occur, Term)>, } @@ -64,32 +83,6 @@ impl MultiTermQuery { tfidf.set_term_names(term_names); tfidf } - - - /// Search the segment. - fn search_segment<'a, 'b, TAccumulator: MultiTermAccumulator>( - &'b self, - reader: &'b SegmentReader, - accumulator: TAccumulator, - mut timer: OpenTimer<'a>) -> Result> { - let mut postings_and_fieldnorms = Vec::with_capacity(self.num_terms()); - { - let mut decode_timer = timer.open("decode_all"); - for &(occur, ref term) in &self.occur_terms { - let _decode_one_timer = decode_timer.open("decode_one"); - if let Some(postings) = reader.read_postings(term, SegmentPostingsOption::Freq) { - let field = term.field(); - let fieldnorm_reader = try!(reader.get_fieldnorms_reader(field)); - postings_and_fieldnorms.push((occur, postings, fieldnorm_reader)); - } - } - } - if postings_and_fieldnorms.len() > 64 { - // TODO putting the SHOULD at the end of the list should push the limit. - return Err(Error::InvalidArgument(String::from("Limit of 64 terms was exceeded."))); - } - Ok(DAATMultiTermScorer::new(postings_and_fieldnorms, accumulator)) - } } @@ -114,62 +107,17 @@ impl From> for MultiTermQuery { } impl Query for MultiTermQuery { - - fn explain( - &self, - searcher: &Searcher, - doc_address: &DocAddress) -> Result { - let segment_reader = searcher.segment_reader(doc_address.segment_ord() as usize); - let similitude = SimilarityExplainer::from(self.similitude(searcher)); - let mut timer_tree = TimerTree::default(); - let mut postings = try!( - self.search_segment( - segment_reader, - similitude, - timer_tree.open("explain")) - ); - Ok(match postings.skip_next(doc_address.doc()) { - SkipResult::Reached => { - let scorer = postings.scorer(); - scorer.explain_score() - } - _ => { - let mut explanation = Explanation::with_val(0f32); - explanation.description(&format!("Failed to run explain: the document {:?} does not match", doc_address)); - explanation - } - }) + + + fn weight(&self, searcher: &Searcher) -> Result> { + let similitude = self.similitude(searcher); + Ok( + Box::new(MultiTermWeight { + query: self.clone(), + similitude: similitude + }) + ) } - fn search( - &self, - searcher: &Searcher, - collector: &mut C) -> Result { - let mut timer_tree = TimerTree::default(); - { - let mut search_timer = timer_tree.open("search"); - for (segment_ord, segment_reader) in searcher.segment_readers().iter().enumerate() { - let mut segment_search_timer = search_timer.open("segment_search"); - { - let _ = segment_search_timer.open("set_segment"); - try!(collector.set_segment(segment_ord as SegmentLocalId, &segment_reader)); - } - let mut postings = try!( - self.search_segment( - segment_reader, - self.similitude(searcher), - segment_search_timer.open("get_postings")) - ); - { - let _collection_timer = segment_search_timer.open("collection"); - while postings.advance() { - let scored_doc = ScoredDoc(postings.score(), postings.doc()); - collector.collect(scored_doc); - } - } - } - } - Ok(timer_tree) - } } diff --git a/src/query/phrase_query.rs b/src/query/phrase_query.rs index a0ce11748..81e1418f6 100644 --- a/src/query/phrase_query.rs +++ b/src/query/phrase_query.rs @@ -17,7 +17,7 @@ pub struct PhraseQuery { impl Query for PhraseQuery { - fn search(&self, searcher: &Searcher, collector: &mut C) -> io::Result { + fn search(&self, searcher: &Searcher, collector: &mut Collector) -> io::Result { let mut timer_tree = TimerTree::default(); { let mut search_timer = timer_tree.open("search"); diff --git a/src/query/query.rs b/src/query/query.rs index 5ef8e5823..6139cac05 100644 --- a/src/query/query.rs +++ b/src/query/query.rs @@ -2,8 +2,9 @@ use Result; use collector::Collector; use core::searcher::Searcher; use common::TimerTree; -use DocAddress; -use query::Explanation; +use SegmentLocalId; +use ScoredDoc; +use super::Weight; /// Queries represent the query of the user, and are in charge @@ -11,16 +12,38 @@ use query::Explanation; /// sent to the collector, as well as the way to score the /// documents. pub trait Query { - - /// Perform the search operation - fn search( - &self, - searcher: &Searcher, - collector: &mut C) -> Result; + + + fn weight(&self, searcher: &Searcher) -> Result>; - /// Explain the score of a specific document - fn explain( + /// Perform the search operation + fn search( &self, searcher: &Searcher, - doc_address: &DocAddress) -> Result; + collector: &mut Collector) -> Result { + + let mut timer_tree = TimerTree::default(); + let weight = try!(self.weight(searcher)); + + { + let mut search_timer = timer_tree.open("search"); + for (segment_ord, segment_reader) in searcher.segment_readers().iter().enumerate() { + let mut segment_search_timer = search_timer.open("segment_search"); + { + let _ = segment_search_timer.open("set_segment"); + try!(collector.set_segment(segment_ord as SegmentLocalId, &segment_reader)); + } + let mut scorer = try!(weight.scorer(segment_reader)); + { + let _collection_timer = segment_search_timer.open("collection"); + while scorer.advance() { + let scored_doc = ScoredDoc(scorer.score(), scorer.doc()); + collector.collect(scored_doc); + } + } + } + } + Ok(timer_tree) + } + } diff --git a/src/query/query_parser.rs b/src/query/query_parser.rs index 89876825d..ac1419acb 100644 --- a/src/query/query_parser.rs +++ b/src/query/query_parser.rs @@ -1,14 +1,13 @@ -use Result as tantivy_Error; +use Result as tantivy_Result; use combine::*; use collector::Collector; +use super::Weight; use core::searcher::Searcher; use common::TimerTree; use query::{Query, MultiTermQuery}; use schema::{Schema, FieldType, Term, Field}; use analyzer::SimpleTokenizer; use analyzer::StreamingIterator; -use DocAddress; -use query::Explanation; use query::Occur; @@ -168,7 +167,17 @@ impl QueryParser { impl Query for StandardQuery { - fn search(&self, searcher: &Searcher, collector: &mut C) -> tantivy_Error { + + + fn weight(&self, searcher: &Searcher) -> tantivy_Result> { + match *self { + StandardQuery::MultiTerm(ref q) => { + q.weight(searcher) + } + } + } + + fn search(&self, searcher: &Searcher, collector: &mut Collector) -> tantivy_Result { match *self { StandardQuery::MultiTerm(ref q) => { q.search(searcher, collector) @@ -176,14 +185,6 @@ impl Query for StandardQuery { } } - fn explain( - &self, - searcher: &Searcher, - doc_address: &DocAddress) -> tantivy_Error { - match *self { - StandardQuery::MultiTerm(ref q) => q.explain(searcher, doc_address) - } - } } diff --git a/src/query/weight.rs b/src/query/weight.rs new file mode 100644 index 000000000..27a7afd65 --- /dev/null +++ b/src/query/weight.rs @@ -0,0 +1,11 @@ +use super::Scorer; +use Result; +use core::SegmentReader; + +pub trait Weight { + + + fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result>; + + +} From 67b1071412a13e69cd49eb40ec6af7998b8c0f26 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sat, 29 Oct 2016 14:41:50 +0900 Subject: [PATCH 02/19] issue/50 Moved segment local collection to the DocSet object. --- src/query/query.rs | 5 +---- src/query/scorer.rs | 10 +++++++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/query/query.rs b/src/query/query.rs index 6139cac05..bb7060604 100644 --- a/src/query/query.rs +++ b/src/query/query.rs @@ -36,10 +36,7 @@ pub trait Query { let mut scorer = try!(weight.scorer(segment_reader)); { let _collection_timer = segment_search_timer.open("collection"); - while scorer.advance() { - let scored_doc = ScoredDoc(scorer.score(), scorer.doc()); - collector.collect(scored_doc); - } + scorer.collect(collector); } } } diff --git a/src/query/scorer.rs b/src/query/scorer.rs index 7e9f17424..21d57ee71 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -1,5 +1,6 @@ use DocSet; - +use ScoredDoc; +use collector::Collector; /// Scored `DocSet` pub trait Scorer: DocSet { @@ -8,6 +9,13 @@ pub trait Scorer: DocSet { /// /// This method will perform a bit of computation and is not cached. fn score(&self,) -> f32; + + fn collect(&mut self, collector: &mut Collector) { + while self.advance() { + let scored_doc = ScoredDoc(self.score(), self.doc()); + collector.collect(scored_doc); + } + } } From 332be6d581dd23a40cdd2c04675f44ea0ba8a065 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sat, 29 Oct 2016 15:09:46 +0900 Subject: [PATCH 03/19] issue/50 Removed ScoredDoc --- src/collector/chained_collector.rs | 18 ++++++++--------- src/collector/count_collector.rs | 11 +++++----- src/collector/mod.rs | 21 ++++++++++---------- src/collector/multi_collector.rs | 14 ++++++------- src/collector/top_collector.rs | 32 +++++++++++++++--------------- src/lib.rs | 16 --------------- src/query/multi_term_query.rs | 3 +-- src/query/query.rs | 1 - src/query/scorer.rs | 6 +----- 9 files changed, 50 insertions(+), 72 deletions(-) diff --git a/src/collector/chained_collector.rs b/src/collector/chained_collector.rs index e52164f6a..5840eb775 100644 --- a/src/collector/chained_collector.rs +++ b/src/collector/chained_collector.rs @@ -2,7 +2,8 @@ use collector::Collector; use SegmentLocalId; use SegmentReader; use std::io; -use ScoredDoc; +use DocId; +use Score; /// Collector that does nothing. @@ -15,7 +16,7 @@ impl Collector for DoNothingCollector { Ok(()) } #[inline] - fn collect(&mut self, _: ScoredDoc) {} + fn collect(&mut self, _doc: DocId, _score: Score) {} } /// Zero-cost abstraction used to collect on multiple collectors. @@ -43,9 +44,9 @@ impl Collector for ChainedCollector ChainedCollector { mod tests { use super::*; - use ScoredDoc; use collector::{Collector, CountCollector, TopCollector}; #[test] @@ -73,9 +73,9 @@ mod tests { let mut collectors = chain() .push(&mut top_collector) .push(&mut count_collector); - collectors.collect(ScoredDoc(0.2, 1)); - collectors.collect(ScoredDoc(0.1, 2)); - collectors.collect(ScoredDoc(0.5, 3)); + collectors.collect(1, 0.2); + collectors.collect(2, 0.1); + collectors.collect(3, 0.5); } assert_eq!(count_collector.count(), 3); assert!(top_collector.at_capacity()); diff --git a/src/collector/count_collector.rs b/src/collector/count_collector.rs index 44d547ec3..8a9014a25 100644 --- a/src/collector/count_collector.rs +++ b/src/collector/count_collector.rs @@ -1,6 +1,7 @@ use std::io; use super::Collector; -use ScoredDoc; +use DocId; +use Score; use SegmentReader; use SegmentLocalId; @@ -31,7 +32,7 @@ impl Collector for CountCollector { Ok(()) } - fn collect(&mut self, _: ScoredDoc) { + fn collect(&mut self, _: DocId, _: Score) { self.count += 1; } } @@ -41,16 +42,14 @@ mod tests { use super::*; use test::Bencher; - use ScoredDoc; use collector::Collector; #[bench] fn build_collector(b: &mut Bencher) { b.iter(|| { let mut count_collector = CountCollector::default(); - let docs: Vec = (0..1_000_000).collect(); - for doc in docs { - count_collector.collect(ScoredDoc(1f32, doc)); + for doc in 0..1_000_000 { + count_collector.collect(doc, 1f32); } count_collector.count() }); diff --git a/src/collector/mod.rs b/src/collector/mod.rs index 683b7eb1c..84bc38485 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -1,6 +1,7 @@ use SegmentReader; use SegmentLocalId; -use ScoredDoc; +use DocId; +use Score; use std::io; mod count_collector; @@ -49,7 +50,7 @@ pub trait Collector { /// on this segment. fn set_segment(&mut self, segment_local_id: SegmentLocalId, segment: &SegmentReader) -> io::Result<()>; /// The query pushes the scored document to the collector via this method. - fn collect(&mut self, scored_doc: ScoredDoc); + fn collect(&mut self, doc: DocId, score: Score); } @@ -58,8 +59,8 @@ impl<'a, C: Collector> Collector for &'a mut C { (*self).set_segment(segment_local_id, segment) } /// The query pushes the scored document to the collector via this method. - fn collect(&mut self, scored_doc: ScoredDoc) { - (*self).collect(scored_doc); + fn collect(&mut self, doc: DocId, score: Score) { + (*self).collect(doc, score); } } @@ -69,8 +70,8 @@ pub mod tests { use super::*; use test::Bencher; - use ScoredDoc; use DocId; + use Score; use core::SegmentReader; use std::io; use SegmentLocalId; @@ -112,8 +113,8 @@ pub mod tests { Ok(()) } - fn collect(&mut self, scored_doc: ScoredDoc) { - self.docs.push(scored_doc.doc() + self.offset); + fn collect(&mut self, doc: DocId, _score: Score) { + self.docs.push(doc + self.offset); } } @@ -150,8 +151,8 @@ pub mod tests { Ok(()) } - fn collect(&mut self, scored_doc: ScoredDoc) { - let val = self.ff_reader.as_ref().unwrap().get(scored_doc.doc()); + fn collect(&mut self, doc: DocId, _score: Score) { + let val = self.ff_reader.as_ref().unwrap().get(doc); self.vals.push(val); } } @@ -163,7 +164,7 @@ pub mod tests { let mut count_collector = CountCollector::default(); let docs: Vec = (0..1_000_000).collect(); for doc in docs { - count_collector.collect(ScoredDoc(1f32, doc)); + count_collector.collect(doc, 1f32); } count_collector.count() }); diff --git a/src/collector/multi_collector.rs b/src/collector/multi_collector.rs index 92958018d..6ce999e80 100644 --- a/src/collector/multi_collector.rs +++ b/src/collector/multi_collector.rs @@ -1,6 +1,7 @@ use std::io; use super::Collector; -use ScoredDoc; +use DocId; +use Score; use SegmentReader; use SegmentLocalId; @@ -31,9 +32,9 @@ impl<'a> Collector for MultiCollector<'a> { Ok(()) } - fn collect(&mut self, scored_doc: ScoredDoc) { + fn collect(&mut self, doc: DocId, score: Score) { for collector in &mut self.collectors { - collector.collect(scored_doc); + collector.collect(doc, score); } } } @@ -44,7 +45,6 @@ impl<'a> Collector for MultiCollector<'a> { mod tests { use super::*; - use ScoredDoc; use collector::{Collector, CountCollector, TopCollector}; #[test] @@ -53,9 +53,9 @@ mod tests { let mut count_collector = CountCollector::default(); { let mut collectors = MultiCollector::from(vec!(&mut top_collector, &mut count_collector)); - collectors.collect(ScoredDoc(0.2, 1)); - collectors.collect(ScoredDoc(0.1, 2)); - collectors.collect(ScoredDoc(0.5, 3)); + collectors.collect(1, 0.2); + collectors.collect(2, 0.1); + collectors.collect(3, 0.5); } assert_eq!(count_collector.count(), 3); assert!(top_collector.at_capacity()); diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index e7fd0d018..21c023caf 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -1,11 +1,11 @@ use std::io; use super::Collector; -use ScoredDoc; use SegmentReader; use SegmentLocalId; use DocAddress; use std::collections::BinaryHeap; use std::cmp::Ordering; +use DocId; use Score; // Rust heap is a max-heap and we need a min heap. @@ -13,6 +13,7 @@ use Score; struct GlobalScoredDoc { score: Score, doc_address: DocAddress + } impl PartialOrd for GlobalScoredDoc { @@ -109,20 +110,20 @@ impl Collector for TopCollector { Ok(()) } - fn collect(&mut self, scored_doc: ScoredDoc) { + fn collect(&mut self, doc: DocId, score: Score) { if self.at_capacity() { // It's ok to unwrap as long as a limit of 0 is forbidden. let limit_doc: GlobalScoredDoc = *self.heap.peek().expect("Top collector with size 0 is forbidden"); - if limit_doc.score < scored_doc.score() { + if limit_doc.score < score { let mut mut_head = self.heap.peek_mut().expect("Top collector with size 0 is forbidden"); - mut_head.score = scored_doc.score(); - mut_head.doc_address = DocAddress(self.segment_id, scored_doc.doc()); + mut_head.score = score; + mut_head.doc_address = DocAddress(self.segment_id, doc); } } else { let wrapped_doc = GlobalScoredDoc { - score: scored_doc.score(), - doc_address: DocAddress(self.segment_id, scored_doc.doc()) + score: score, + doc_address: DocAddress(self.segment_id, doc) }; self.heap.push(wrapped_doc); } @@ -135,7 +136,6 @@ impl Collector for TopCollector { mod tests { use super::*; - use ScoredDoc; use DocId; use Score; use collector::Collector; @@ -143,9 +143,9 @@ mod tests { #[test] fn test_top_collector_not_at_capacity() { let mut top_collector = TopCollector::with_limit(4); - top_collector.collect(ScoredDoc(0.8, 1)); - top_collector.collect(ScoredDoc(0.2, 3)); - top_collector.collect(ScoredDoc(0.3, 5)); + top_collector.collect(1, 0.8); + top_collector.collect(3, 0.2); + top_collector.collect(5, 0.3); assert!(!top_collector.at_capacity()); let score_docs: Vec<(Score, DocId)> = top_collector.score_docs() .into_iter() @@ -159,11 +159,11 @@ mod tests { #[test] fn test_top_collector_at_capacity() { let mut top_collector = TopCollector::with_limit(4); - top_collector.collect(ScoredDoc(0.8, 1)); - top_collector.collect(ScoredDoc(0.2, 3)); - top_collector.collect(ScoredDoc(0.3, 5)); - top_collector.collect(ScoredDoc(0.9, 7)); - top_collector.collect(ScoredDoc(-0.2, 9)); + top_collector.collect(1, 0.8); + top_collector.collect(3, 0.2); + top_collector.collect(5, 0.3); + top_collector.collect(7, 0.9); + top_collector.collect(9, -0.2); assert!(top_collector.at_capacity()); { let score_docs: Vec<(Score, DocId)> = top_collector diff --git a/src/lib.rs b/src/lib.rs index 9106fe8a5..e22386b7c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -139,22 +139,6 @@ impl DocAddress { } } -/// A scored doc is simply a couple `(score, doc_id)` -#[derive(Clone, Copy)] -pub struct ScoredDoc(Score, DocId); - -impl ScoredDoc { - - /// Returns the score - pub fn score(&self,) -> Score { - self.0 - } - - /// Returns the doc - pub fn doc(&self,) -> DocId { - self.1 - } -} /// `DocAddress` contains all the necessary information /// to identify a document given a `Searcher` object. diff --git a/src/query/multi_term_query.rs b/src/query/multi_term_query.rs index aa0287e76..f367669c4 100644 --- a/src/query/multi_term_query.rs +++ b/src/query/multi_term_query.rs @@ -107,8 +107,7 @@ impl From> for MultiTermQuery { } impl Query for MultiTermQuery { - - + fn weight(&self, searcher: &Searcher) -> Result> { let similitude = self.similitude(searcher); Ok( diff --git a/src/query/query.rs b/src/query/query.rs index bb7060604..869f1f41a 100644 --- a/src/query/query.rs +++ b/src/query/query.rs @@ -3,7 +3,6 @@ use collector::Collector; use core::searcher::Searcher; use common::TimerTree; use SegmentLocalId; -use ScoredDoc; use super::Weight; diff --git a/src/query/scorer.rs b/src/query/scorer.rs index 21d57ee71..fc6525800 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -1,5 +1,4 @@ use DocSet; -use ScoredDoc; use collector::Collector; /// Scored `DocSet` @@ -12,10 +11,7 @@ pub trait Scorer: DocSet { fn collect(&mut self, collector: &mut Collector) { while self.advance() { - let scored_doc = ScoredDoc(self.score(), self.doc()); - collector.collect(scored_doc); + collector.collect(self.doc(), self.score()); } } } - - From 767fa94d182a4df8097540b81e8a1730f0fefdca Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sat, 29 Oct 2016 17:13:12 +0900 Subject: [PATCH 04/19] issue/50 Removed StandardQuery --- examples/simple_search.rs | 1 - src/query/multi_term_query.rs | 7 +++- src/query/query.rs | 7 ++-- src/query/query_parser.rs | 66 +++++++---------------------------- 4 files changed, 24 insertions(+), 57 deletions(-) diff --git a/examples/simple_search.rs b/examples/simple_search.rs index 42b0600c8..2ca533f1d 100644 --- a/examples/simple_search.rs +++ b/examples/simple_search.rs @@ -8,7 +8,6 @@ use tantivy::Index; use tantivy::schema::*; use tantivy::collector::TopCollector; use tantivy::query::QueryParser; -use tantivy::query::Query; fn main() { // Let's create a temporary directory for the diff --git a/src/query/multi_term_query.rs b/src/query/multi_term_query.rs index f367669c4..15d63bfdf 100644 --- a/src/query/multi_term_query.rs +++ b/src/query/multi_term_query.rs @@ -1,5 +1,6 @@ use Result; use super::Weight; +use std::any::Any; use Error; use schema::Term; use query::Query; @@ -107,7 +108,11 @@ impl From> for MultiTermQuery { } impl Query for MultiTermQuery { - + + fn as_any(&self) -> &Any { + self + } + fn weight(&self, searcher: &Searcher) -> Result> { let similitude = self.similitude(searcher); Ok( diff --git a/src/query/query.rs b/src/query/query.rs index 869f1f41a..0e27c1315 100644 --- a/src/query/query.rs +++ b/src/query/query.rs @@ -4,14 +4,17 @@ use core::searcher::Searcher; use common::TimerTree; use SegmentLocalId; use super::Weight; +use std::fmt; +use std::any::Any; /// Queries represent the query of the user, and are in charge /// of the logic defining the set of documents that should be /// sent to the collector, as well as the way to score the /// documents. -pub trait Query { - +pub trait Query: fmt::Debug { + + fn as_any(&self) -> &Any; fn weight(&self, searcher: &Searcher) -> Result>; diff --git a/src/query/query_parser.rs b/src/query/query_parser.rs index ac1419acb..3d0ee87ec 100644 --- a/src/query/query_parser.rs +++ b/src/query/query_parser.rs @@ -1,9 +1,4 @@ -use Result as tantivy_Result; use combine::*; -use collector::Collector; -use super::Weight; -use core::searcher::Searcher; -use common::TimerTree; use query::{Query, MultiTermQuery}; use schema::{Schema, FieldType, Term, Field}; use analyzer::SimpleTokenizer; @@ -60,23 +55,6 @@ pub struct QueryParser { } -/// The `QueryParser` returns a `StandardQuery`. -#[derive(Eq, PartialEq, Debug)] -pub enum StandardQuery { - MultiTerm(MultiTermQuery), -} - -impl StandardQuery { - /// Number of terms involved in the query. - pub fn num_terms(&self,) -> usize { - match *self { - StandardQuery::MultiTerm(ref q) => { - q.num_terms() - } - } - } -} - impl QueryParser { /// Creates a `QueryParser` @@ -141,7 +119,7 @@ impl QueryParser { /// /// Implementing a lenient mode for this query parser is tracked /// in [Issue 5](https://github.com/fulmicoton/tantivy/issues/5) - pub fn parse_query(&self, query: &str) -> Result { + pub fn parse_query(&self, query: &str) -> Result, ParsingError> { match parser(query_language).parse(query.trim()) { Ok(literals) => { let mut terms_result: Vec<(Occur, Term)> = Vec::new(); @@ -153,9 +131,7 @@ impl QueryParser { .map(|term| (occur, term) )); } Ok( - StandardQuery::MultiTerm( - MultiTermQuery::from(terms_result) - ) + box MultiTermQuery::from(terms_result) ) } Err(_) => { @@ -166,26 +142,6 @@ impl QueryParser { } -impl Query for StandardQuery { - - - fn weight(&self, searcher: &Searcher) -> tantivy_Result> { - match *self { - StandardQuery::MultiTerm(ref q) => { - q.weight(searcher) - } - } - } - - fn search(&self, searcher: &Searcher, collector: &mut Collector) -> tantivy_Result { - match *self { - StandardQuery::MultiTerm(ref q) => { - q.search(searcher, collector) - } - } - } - -} fn compute_terms(field: Field, text: &str) -> Vec { @@ -325,6 +281,10 @@ mod test { assert!(query_parser.parse("f:@e!e").is_err()); } + // fn extract(query_parser: &QueryParser, q: &str) -> T { + // query_parser.parse_query(q).unwrap().as_any().downcast_ref::().unwrap(), + // } + #[test] pub fn test_query_parser() { let mut schema_builder = SchemaBuilder::default(); @@ -335,9 +295,9 @@ mod test { assert!(query_parser.parse_query("a:b").is_err()); { let terms = vec!(Term::from_field_text(title_field, "abctitle")); - let query = StandardQuery::MultiTerm(MultiTermQuery::from(terms)); + let query = MultiTermQuery::from(terms); assert_eq!( - query_parser.parse_query("title:abctitle").unwrap(), + *query_parser.parse_query("title:abctitle").unwrap().as_any().downcast_ref::().unwrap(), query ); } @@ -346,21 +306,21 @@ mod test { Term::from_field_text(text_field, "abctitle"), Term::from_field_text(author_field, "abctitle"), ); - let query = StandardQuery::MultiTerm(MultiTermQuery::from(terms)); + let query = MultiTermQuery::from(terms); assert_eq!( - query_parser.parse_query("abctitle").unwrap(), + *query_parser.parse_query("abctitle").unwrap().as_any().downcast_ref::().unwrap(), query ); } { let terms = vec!(Term::from_field_text(title_field, "abctitle")); - let query = StandardQuery::MultiTerm(MultiTermQuery::from(terms)); + let query = MultiTermQuery::from(terms); assert_eq!( - query_parser.parse_query("title:abctitle ").unwrap(), + *query_parser.parse_query("title:abctitle ").unwrap().as_any().downcast_ref::().unwrap(), query ); assert_eq!( - query_parser.parse_query(" title:abctitle").unwrap(), + *query_parser.parse_query(" title:abctitle").unwrap().as_any().downcast_ref::().unwrap(), query ); } From 8b5f7af688fe442c7f30447473034aafbc242361 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sun, 30 Oct 2016 10:06:51 +0900 Subject: [PATCH 05/19] issue/50 Added documentation for the query object. --- src/query/query.rs | 50 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/src/query/query.rs b/src/query/query.rs index 0e27c1315..3b15f3010 100644 --- a/src/query/query.rs +++ b/src/query/query.rs @@ -8,17 +8,55 @@ use std::fmt; use std::any::Any; -/// Queries represent the query of the user, and are in charge -/// of the logic defining the set of documents that should be -/// sent to the collector, as well as the way to score the -/// documents. +/// Query trait are in charge of defining : +/// +/// - a set of documents +/// - a way to score these documents +/// +/// When performing a [search](#method.search), these documents will then +/// be pushed to a [Collector](../collector/trait.Collector.html), +/// which will in turn be in charge of deciding what to do with them. +/// +/// Concretely, this scored docset is represented by the +/// [`Scorer`](./trait.Scorer.html) trait. +/// +/// Because our index is actually split into segments, the +/// query does not actually directly creates `DocSet` object. +/// Instead, the query creates a [`Weight`](./trait.Weight.html) +/// object for a given searcher. +/// +/// The weight object, in turn, makes it possible to create +/// a scorer for a specific [`SegmentReader`](../struct.SegmentReader.html). +/// +/// So to sum it up : +/// - a `Query` is recipe to define a set of documents as well the way to score them. +/// - a `Weight` is this recipe tied to a specific `Searcher`. It may for instance +/// hold statistics about the different term of the query. It is created by the query. +/// - a `Scorer` is a cursor over the set of matching documents, for a specific +/// [`SegmentReader`](../struct.SegmentReader.html). It is created by the [`Weight`](./trait.Weight.html). +/// +/// When implementing a new type of `Query`, it is normal to implement a +/// dedicated `Query`, `Weight` and `Scorer`. pub trait Query: fmt::Debug { + /// Used to make it possible to cast Box + /// into a specific type. This is mostly useful for unit tests. fn as_any(&self) -> &Any; + /// Create the weight associated to a query. + /// + /// See [Weight](./trait.Weight.html). fn weight(&self, searcher: &Searcher) -> Result>; - - /// Perform the search operation + + /// Search works as follows : + /// + /// First the weight object associated to the query is created. + /// + /// Then, the query loops over the segments and for each segment : + /// - setup the collector and informs it that the segment being processed has changed. + /// - creates a `Scorer` object associated for this segment + /// - iterate throw the matched documents and push them to the collector. + /// fn search( &self, searcher: &Searcher, From fa78baf278e38e5fc534f92a292e301a964853a3 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sun, 30 Oct 2016 17:03:00 +0900 Subject: [PATCH 06/19] issue/50 Introducing Term Query / Boolean Query --- src/query/boolean_query/boolean_clause.rs | 18 +++++++++ src/query/boolean_query/boolean_query.rs | 45 +++++++++++++++++++++++ src/query/boolean_query/boolean_scorer.rs | 24 ++++++++++++ src/query/boolean_query/boolean_weight.rs | 33 +++++++++++++++++ src/query/boolean_query/mod.rs | 7 ++++ src/query/daat_multiterm_scorer.rs | 39 ++------------------ src/query/empty_scorer.rs | 22 +++++++++++ src/query/mod.rs | 11 ++++++ src/query/occur_filter.rs | 34 +++++++++++++++++ src/query/term_query/mod.rs | 5 +++ src/query/term_query/term_query.rs | 32 ++++++++++++++++ src/query/term_query/term_scorer.rs | 28 ++++++++++++++ src/query/term_query/term_weight.rs | 32 ++++++++++++++++ 13 files changed, 295 insertions(+), 35 deletions(-) create mode 100644 src/query/boolean_query/boolean_clause.rs create mode 100644 src/query/boolean_query/boolean_query.rs create mode 100644 src/query/boolean_query/boolean_scorer.rs create mode 100644 src/query/boolean_query/boolean_weight.rs create mode 100644 src/query/boolean_query/mod.rs create mode 100644 src/query/empty_scorer.rs create mode 100644 src/query/occur_filter.rs create mode 100644 src/query/term_query/mod.rs create mode 100644 src/query/term_query/term_query.rs create mode 100644 src/query/term_query/term_scorer.rs create mode 100644 src/query/term_query/term_weight.rs diff --git a/src/query/boolean_query/boolean_clause.rs b/src/query/boolean_query/boolean_clause.rs new file mode 100644 index 000000000..34f49f0b7 --- /dev/null +++ b/src/query/boolean_query/boolean_clause.rs @@ -0,0 +1,18 @@ +use query::Occur; +use query::Query; + +#[derive(Debug)] +pub struct BooleanClause { + pub query: Box, + pub occur: Occur, +} + + +impl BooleanClause { + pub fn new(query: Box, occur: Occur) -> BooleanClause { + BooleanClause { + query: query, + occur: occur + } + } +} \ No newline at end of file diff --git a/src/query/boolean_query/boolean_query.rs b/src/query/boolean_query/boolean_query.rs new file mode 100644 index 000000000..2e9c24bcb --- /dev/null +++ b/src/query/boolean_query/boolean_query.rs @@ -0,0 +1,45 @@ +use Result; +use std::any::Any; +use super::boolean_weight::BooleanWeight; +use super::BooleanClause; +use query::Weight; +use Searcher; +use query::Query; +use query::Occur; +use query::OccurFilter; + + +#[derive(Debug)] +pub struct BooleanQuery { + clauses: Vec, +} + +impl From> for BooleanQuery { + fn from(clauses: Vec) -> BooleanQuery { + BooleanQuery { + clauses: clauses, + } + } +} + +impl Query for BooleanQuery { + + fn as_any(&self) -> &Any { + self + } + + fn weight(&self, searcher: &Searcher) -> Result> { + let sub_weights = try!(self.clauses + .iter() + .map(|clause| clause.query.weight(searcher)) + .collect() + ); + let occurs: Vec = self.clauses + .iter() + .map(|clause| clause.occur) + .collect(); + let filter = OccurFilter::new(&occurs); + Ok(box BooleanWeight::new(sub_weights, filter)) + } + +} \ No newline at end of file diff --git a/src/query/boolean_query/boolean_scorer.rs b/src/query/boolean_query/boolean_scorer.rs new file mode 100644 index 000000000..38a624c60 --- /dev/null +++ b/src/query/boolean_query/boolean_scorer.rs @@ -0,0 +1,24 @@ +use query::Scorer; +use DocId; +use postings::DocSet; + +pub struct BooleanScorer { +} + + +impl DocSet for BooleanScorer { + fn advance(&mut self,) -> bool { + panic!("a"); + } + + fn doc(&self,) -> DocId { + panic!("a"); + } +} + +impl Scorer for BooleanScorer { + + fn score(&self,) -> f32 { + panic!(""); + } +} \ No newline at end of file diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs new file mode 100644 index 000000000..550240311 --- /dev/null +++ b/src/query/boolean_query/boolean_weight.rs @@ -0,0 +1,33 @@ +use query::Weight; +use core::SegmentReader; +use query::Scorer; +use query::OccurFilter; +use Result; + +pub struct BooleanWeight { + weights: Vec>, + filter: OccurFilter, +} + +impl BooleanWeight { + pub fn new(weights: Vec>, + filter: OccurFilter) -> BooleanWeight { + BooleanWeight { + weights: weights, + filter: filter, + } + } +} + + +impl Weight for BooleanWeight { + + fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { + // BooleanScorer { + + // } + panic!(""); + + } + +} \ No newline at end of file diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs new file mode 100644 index 000000000..bae095a88 --- /dev/null +++ b/src/query/boolean_query/mod.rs @@ -0,0 +1,7 @@ +mod boolean_clause; +mod boolean_query; +mod boolean_scorer; +mod boolean_weight; + +pub use self::boolean_query::BooleanQuery; +pub use self::boolean_clause::BooleanClause; diff --git a/src/query/daat_multiterm_scorer.rs b/src/query/daat_multiterm_scorer.rs index b0ab1920c..e0b86b5f8 100644 --- a/src/query/daat_multiterm_scorer.rs +++ b/src/query/daat_multiterm_scorer.rs @@ -7,6 +7,7 @@ use query::Similarity; use fastfield::U32FastFieldReader; use query::Occur; use std::iter; +use query::OccurFilter; use super::Scorer; use Score; @@ -35,38 +36,6 @@ impl Ord for HeapItem { } } -struct Filter { - and_mask: u64, - result: u64, -} - -impl Filter { - fn accept(&self, ord_set: u64) -> bool { - (self.and_mask & ord_set) == self.result - } - - fn new(occurs: &[Occur]) -> Filter { - let mut and_mask = 0u64; - let mut result = 0u64; - for (i, occur) in occurs.iter().enumerate() { - let shift = 1 << i; - match *occur { - Occur::Must => { - and_mask |= shift; - result |= shift; - }, - Occur::MustNot => { - and_mask |= shift; - }, - Occur::Should => {}, - } - } - Filter { - and_mask: and_mask, - result: result - } - } -} /// Document-At-A-Time multi term scorer. /// @@ -79,7 +48,7 @@ pub struct DAATMultiTermScorer, doc: DocId, similarity: TAccumulator, - filter: Filter, + filter: OccurFilter, } impl DAATMultiTermScorer { @@ -89,7 +58,7 @@ impl DAATMultiTermScore fieldnorm_readers: Vec, postings: Vec, similarity: TAccumulator, - filter: Filter + filter: OccurFilter ) -> DAATMultiTermScorer { let mut term_frequencies: Vec = iter::repeat(0u32).take(postings.len()).collect(); let heap_items: Vec = postings @@ -129,7 +98,7 @@ impl DAATMultiTermScore occurs.push(occur); } } - let filter = Filter::new(&occurs); + let filter = OccurFilter::new(&occurs); DAATMultiTermScorer::new_non_empty(fieldnorm_readers, postings, similarity, filter) } diff --git a/src/query/empty_scorer.rs b/src/query/empty_scorer.rs new file mode 100644 index 000000000..0c1e989cf --- /dev/null +++ b/src/query/empty_scorer.rs @@ -0,0 +1,22 @@ +use query::Scorer; +use DocSet; +use Score; +use DocId; + +pub struct EmptyScorer; + +impl Scorer for EmptyScorer { + fn score(&self) -> Score { + 0f32 + } +} + +impl DocSet for EmptyScorer { + fn advance(&mut self) -> bool { + false + } + + fn doc(&self) -> DocId { + 0 + } +} diff --git a/src/query/mod.rs b/src/query/mod.rs index 7a516d61f..161f30425 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -4,6 +4,7 @@ /// mod query; +mod boolean_query; mod multi_term_query; mod multi_term_accumulator; mod similarity_explainer; @@ -15,13 +16,23 @@ mod occur; mod daat_multiterm_scorer; mod similarity; mod weight; +mod occur_filter; +mod term_query; +mod empty_scorer; + + +pub use self::empty_scorer::EmptyScorer; + +pub use self::occur_filter::OccurFilter; pub use self::similarity::Similarity; pub use self::daat_multiterm_scorer::DAATMultiTermScorer; +pub use self::boolean_query::BooleanQuery; pub use self::occur::Occur; pub use self::query::Query; +pub use self::term_query::TermQuery; pub use self::multi_term_query::MultiTermQuery; pub use self::similarity_explainer::SimilarityExplainer; pub use self::tfidf::TfIdf; diff --git a/src/query/occur_filter.rs b/src/query/occur_filter.rs new file mode 100644 index 000000000..914c8f42b --- /dev/null +++ b/src/query/occur_filter.rs @@ -0,0 +1,34 @@ +use query::Occur; + +pub struct OccurFilter { + and_mask: u64, + result: u64, +} + +impl OccurFilter { + pub fn accept(&self, ord_set: u64) -> bool { + (self.and_mask & ord_set) == self.result + } + + pub fn new(occurs: &[Occur]) -> OccurFilter { + let mut and_mask = 0u64; + let mut result = 0u64; + for (i, occur) in occurs.iter().enumerate() { + let shift = 1 << i; + match *occur { + Occur::Must => { + and_mask |= shift; + result |= shift; + }, + Occur::MustNot => { + and_mask |= shift; + }, + Occur::Should => {}, + } + } + OccurFilter { + and_mask: and_mask, + result: result + } + } +} diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs new file mode 100644 index 000000000..db8c5ff1d --- /dev/null +++ b/src/query/term_query/mod.rs @@ -0,0 +1,5 @@ +mod term_query; +mod term_weight; +mod term_scorer; + +pub use self::term_query::TermQuery; \ No newline at end of file diff --git a/src/query/term_query/term_query.rs b/src/query/term_query/term_query.rs new file mode 100644 index 000000000..fe18f4e16 --- /dev/null +++ b/src/query/term_query/term_query.rs @@ -0,0 +1,32 @@ +use Term; +use Result; +use super::term_weight::TermWeight; +use query::Query; +use query::Weight; +use Searcher; +use std::any::Any; + +#[derive(Debug)] +pub struct TermQuery { + term: Term, +} + +impl From for TermQuery { + fn from(term: Term) -> TermQuery { + TermQuery { + term: term + } + } +} + +impl Query for TermQuery { + fn as_any(&self) -> &Any { + self + } + + fn weight(&self, _searcher: &Searcher) -> Result> { + Ok(box TermWeight { + term: self.term.clone() + }) + } +} diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs new file mode 100644 index 000000000..dfe73e85f --- /dev/null +++ b/src/query/term_query/term_scorer.rs @@ -0,0 +1,28 @@ +use Score; +use DocId; +use postings::SegmentPostings; +use fastfield::U32FastFieldReader; +use postings::DocSet; +use query::Scorer; + +pub struct TermScorer<'a> { + pub fieldnorm_reader: U32FastFieldReader, + pub segment_postings: SegmentPostings<'a>, +} + +impl<'a> DocSet for TermScorer<'a> { + + fn advance(&mut self,) -> bool { + self.segment_postings.advance() + } + + fn doc(&self,) -> DocId { + self.segment_postings.doc() + } +} + +impl<'a> Scorer for TermScorer<'a> { + fn score(&self,) -> Score { + 1.0 + } +} diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs new file mode 100644 index 000000000..767bcc092 --- /dev/null +++ b/src/query/term_query/term_weight.rs @@ -0,0 +1,32 @@ +use Term; +use query::Weight; +use core::SegmentReader; +use query::Scorer; +use query::EmptyScorer; +use postings::SegmentPostingsOption; +use super::term_scorer::TermScorer; +use Result; + +pub struct TermWeight { + pub term: Term +} + + +impl Weight for TermWeight { + + fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { + let field = self.term.field(); + let fieldnorm_reader = try!(reader.get_fieldnorms_reader(field)); + if let Some(segment_postings) = reader.read_postings(&self.term, SegmentPostingsOption::Freq) { + let scorer: TermScorer = TermScorer { + fieldnorm_reader: fieldnorm_reader, + segment_postings: segment_postings, + }; + Ok(box scorer) + } + else { + Ok(box EmptyScorer) + } + } + +} \ No newline at end of file From 893932dff85789f9633dc7ec08e9e6ccaa334484 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sun, 30 Oct 2016 17:03:37 +0900 Subject: [PATCH 07/19] issue/50 Implementation ooleanScorer. --- src/query/boolean_query/boolean_scorer.rs | 169 +++++++++++++++++++++- 1 file changed, 164 insertions(+), 5 deletions(-) diff --git a/src/query/boolean_query/boolean_scorer.rs b/src/query/boolean_query/boolean_scorer.rs index 38a624c60..b445cd75f 100644 --- a/src/query/boolean_query/boolean_scorer.rs +++ b/src/query/boolean_query/boolean_scorer.rs @@ -1,14 +1,172 @@ use query::Scorer; use DocId; +use Score; +use std::collections::BinaryHeap; +use std::cmp::Ordering; use postings::DocSet; +use query::OccurFilter; -pub struct BooleanScorer { + +struct ScoreCombiner { + coords: Vec, + num_fields: usize, + score: Score, +} + +impl ScoreCombiner { + + fn update(&mut self, score: Score) { + self.score += score; + self.num_fields += 1; + } + + fn clear(&mut self,) { + self.score = 0f32; + self.num_fields = 0; + } + + /// Compute the coord term + fn coord(&self,) -> f32 { + self.coords[self.num_fields] + } + + #[inline] + fn score(&self, ) -> Score { + self.score * self.coord() + } +} + +impl From> for ScoreCombiner { + fn from(coords: Vec) -> ScoreCombiner { + ScoreCombiner { + coords: coords, + num_fields: 0, + score: 0f32, + } + } } -impl DocSet for BooleanScorer { +/// Each `HeapItem` represents the head of +/// a segment postings being merged. +/// +/// * `doc` - is the current doc id for the given segment postings +/// * `ord` - is the ordinal used to identify to which segment postings +/// this heap item belong to. +#[derive(Eq, PartialEq)] +struct HeapItem { + doc: DocId, + ord: u32, +} + +/// `HeapItem` are ordered by the document +impl PartialOrd for HeapItem { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for HeapItem { + fn cmp(&self, other:&Self) -> Ordering { + (other.doc).cmp(&self.doc) + } +} + +pub struct BooleanScorer { + postings: Vec, + queue: BinaryHeap, + doc: DocId, + score_combiner: ScoreCombiner, + filter: OccurFilter, +} + +impl BooleanScorer { + + fn new(postings: Vec, filter: OccurFilter) -> BooleanScorer { + let num_postings = postings.len(); + let query_coords: Vec = (0..num_postings + 1) + .map(|i| (i as Score) / (num_postings as Score)) + .collect(); + let score_combiner = ScoreCombiner::from(query_coords); + let heap_items: Vec = postings + .iter() + .map(|posting| posting.doc()) + .enumerate() + .map(|(ord, doc)| { + HeapItem { + doc: doc, + ord: ord as u32 + } + }) + .collect(); + BooleanScorer { + postings: postings, + queue: BinaryHeap::from(heap_items), + doc: 0u32, + score_combiner: score_combiner, + filter: filter, + + } + } + + + /// Advances the head of our heap (the segment postings with the lowest doc) + /// It will also update the new current `DocId` as well as the term frequency + /// associated with the segment postings. + /// + /// After advancing the `SegmentPosting`, the postings is removed from the heap + /// if it has been entirely consumed, or pushed back into the heap. + /// + /// # Panics + /// This method will panic if the head `SegmentPostings` is not empty. + fn advance_head(&mut self,) { + { + let mut mutable_head = self.queue.peek_mut().unwrap(); + let cur_postings = &mut self.postings[mutable_head.ord as usize]; + if cur_postings.advance() { + mutable_head.doc = cur_postings.doc(); + return; + } + + } + self.queue.pop(); + } +} + +impl DocSet for BooleanScorer { fn advance(&mut self,) -> bool { - panic!("a"); + loop { + self.score_combiner.clear(); + let mut ord_bitset = 0u64; + match self.queue.peek() { + Some(heap_item) => { + let ord = heap_item.ord as usize; + self.doc = heap_item.doc; + let score = self.postings[ord].score(); + self.score_combiner.update(score); + ord_bitset |= 1 << ord; + } + None => { + return false; + } + } + self.advance_head(); + while let Some(&HeapItem {doc, ord}) = self.queue.peek() { + if doc == self.doc { + let ord = ord as usize; + let score = self.postings[ord].score(); + self.score_combiner.update(score); + ord_bitset |= 1 << ord; + } + else { + break; + } + self.advance_head(); + } + if self.filter.accept(ord_bitset) { + return true; + } + } } fn doc(&self,) -> DocId { @@ -16,9 +174,10 @@ impl DocSet for BooleanScorer { } } -impl Scorer for BooleanScorer { +impl Scorer for BooleanScorer { fn score(&self,) -> f32 { panic!(""); } -} \ No newline at end of file +} + From 7421e0a48de3b88dc3e6a260ea2d08c933f052c2 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sun, 30 Oct 2016 18:47:24 +0900 Subject: [PATCH 08/19] issue/50 going on --- src/query/daat_multiterm_scorer.rs | 1 - src/query/multi_term_query.rs | 2 +- src/query/scorer.rs | 1 + src/query/term_query/term_query.rs | 4 +++- src/query/term_query/term_scorer.rs | 6 +++++- src/query/term_query/term_weight.rs | 5 ++++- 6 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/query/daat_multiterm_scorer.rs b/src/query/daat_multiterm_scorer.rs index e0b86b5f8..98f51f49f 100644 --- a/src/query/daat_multiterm_scorer.rs +++ b/src/query/daat_multiterm_scorer.rs @@ -117,7 +117,6 @@ impl DAATMultiTermScore /// # Panics /// This method will panic if the head `SegmentPostings` is not empty. fn advance_head(&mut self,) { - { let mut mutable_head = self.queue.peek_mut().unwrap(); let cur_postings = &mut self.postings[mutable_head.ord as usize]; diff --git a/src/query/multi_term_query.rs b/src/query/multi_term_query.rs index 15d63bfdf..2f4340b29 100644 --- a/src/query/multi_term_query.rs +++ b/src/query/multi_term_query.rs @@ -22,9 +22,9 @@ struct MultiTermWeight { impl Weight for MultiTermWeight { - fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { + let mut postings_and_fieldnorms = Vec::with_capacity(self.query.num_terms()); { for &(occur, ref term) in &self.query.occur_terms { diff --git a/src/query/scorer.rs b/src/query/scorer.rs index fc6525800..32ed4b443 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -15,3 +15,4 @@ pub trait Scorer: DocSet { } } } + diff --git a/src/query/term_query/term_query.rs b/src/query/term_query/term_query.rs index fe18f4e16..e763a36ae 100644 --- a/src/query/term_query/term_query.rs +++ b/src/query/term_query/term_query.rs @@ -24,8 +24,10 @@ impl Query for TermQuery { self } - fn weight(&self, _searcher: &Searcher) -> Result> { + fn weight(&self, searcher: &Searcher) -> Result> { + let doc_freq = searcher.doc_freq(&self.term); Ok(box TermWeight { + doc_freq: doc_freq, term: self.term.clone() }) } diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index dfe73e85f..6ad5c591d 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -4,8 +4,10 @@ use postings::SegmentPostings; use fastfield::U32FastFieldReader; use postings::DocSet; use query::Scorer; +use postings::Postings; pub struct TermScorer<'a> { + pub idf: Score, pub fieldnorm_reader: U32FastFieldReader, pub segment_postings: SegmentPostings<'a>, } @@ -23,6 +25,8 @@ impl<'a> DocSet for TermScorer<'a> { impl<'a> Scorer for TermScorer<'a> { fn score(&self,) -> Score { - 1.0 + let doc = self.segment_postings.doc(); + let field_norm = self.fieldnorm_reader.get(doc); + self.idf * (self.segment_postings.term_freq() as f32 / field_norm as f32).sqrt() } } diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs index 767bcc092..b00f1ce16 100644 --- a/src/query/term_query/term_weight.rs +++ b/src/query/term_query/term_weight.rs @@ -1,4 +1,5 @@ use Term; +use Score; use query::Weight; use core::SegmentReader; use query::Scorer; @@ -8,7 +9,8 @@ use super::term_scorer::TermScorer; use Result; pub struct TermWeight { - pub term: Term + pub doc_freq: u32, + pub term: Term, } @@ -19,6 +21,7 @@ impl Weight for TermWeight { let fieldnorm_reader = try!(reader.get_fieldnorms_reader(field)); if let Some(segment_postings) = reader.read_postings(&self.term, SegmentPostingsOption::Freq) { let scorer: TermScorer = TermScorer { + idf: 1f32 / (self.doc_freq as f32), fieldnorm_reader: fieldnorm_reader, segment_postings: segment_postings, }; From 5f96823e80c3efab59982a49570ffea1f85ffb58 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Mon, 31 Oct 2016 10:12:58 +0900 Subject: [PATCH 09/19] issue/50 Switched to BooleanQueryScorer --- src/query/boolean_query/boolean_scorer.rs | 17 ++- src/query/boolean_query/boolean_weight.rs | 6 +- src/query/boolean_query/mod.rs | 1 + src/query/multi_term_query.rs | 124 +++++++++------------- src/query/occur_filter.rs | 1 + src/query/query.rs | 1 - src/query/term_query/mod.rs | 4 +- src/query/term_query/term_query.rs | 17 ++- src/query/term_query/term_weight.rs | 40 ++++--- 9 files changed, 107 insertions(+), 104 deletions(-) diff --git a/src/query/boolean_query/boolean_scorer.rs b/src/query/boolean_query/boolean_scorer.rs index b445cd75f..559640230 100644 --- a/src/query/boolean_query/boolean_scorer.rs +++ b/src/query/boolean_query/boolean_scorer.rs @@ -82,13 +82,20 @@ pub struct BooleanScorer { impl BooleanScorer { - fn new(postings: Vec, filter: OccurFilter) -> BooleanScorer { + pub fn new(postings: Vec, filter: OccurFilter) -> BooleanScorer { let num_postings = postings.len(); let query_coords: Vec = (0..num_postings + 1) .map(|i| (i as Score) / (num_postings as Score)) .collect(); let score_combiner = ScoreCombiner::from(query_coords); - let heap_items: Vec = postings + let mut non_empty_postings: Vec = Vec::new(); + for mut posting in postings { + let non_empty = posting.advance(); + if non_empty { + non_empty_postings.push(posting); + } + } + let heap_items: Vec = non_empty_postings .iter() .map(|posting| posting.doc()) .enumerate() @@ -100,7 +107,7 @@ impl BooleanScorer { }) .collect(); BooleanScorer { - postings: postings, + postings: non_empty_postings, queue: BinaryHeap::from(heap_items), doc: 0u32, score_combiner: score_combiner, @@ -170,14 +177,14 @@ impl DocSet for BooleanScorer { } fn doc(&self,) -> DocId { - panic!("a"); + self.doc } } impl Scorer for BooleanScorer { fn score(&self,) -> f32 { - panic!(""); + self.score_combiner.score() } } diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index 550240311..ee0c37860 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -6,15 +6,15 @@ use Result; pub struct BooleanWeight { weights: Vec>, - filter: OccurFilter, + occur_filter: OccurFilter, } impl BooleanWeight { pub fn new(weights: Vec>, - filter: OccurFilter) -> BooleanWeight { + occur_filter: OccurFilter) -> BooleanWeight { BooleanWeight { weights: weights, - filter: filter, + occur_filter: occur_filter, } } } diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index bae095a88..9e7506c2e 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -5,3 +5,4 @@ mod boolean_weight; pub use self::boolean_query::BooleanQuery; pub use self::boolean_clause::BooleanClause; +pub use self::boolean_scorer::BooleanScorer; \ No newline at end of file diff --git a/src/query/multi_term_query.rs b/src/query/multi_term_query.rs index 2f4340b29..511489961 100644 --- a/src/query/multi_term_query.rs +++ b/src/query/multi_term_query.rs @@ -1,52 +1,43 @@ use Result; use super::Weight; use std::any::Any; -use Error; use schema::Term; use query::Query; use core::searcher::Searcher; use core::SegmentReader; -use query::TfIdf; use query::Scorer; use query::occur::Occur; -use postings::SegmentPostingsOption; -use query::DAATMultiTermScorer; - +use query::occur_filter::OccurFilter; +use query::term_query::{TermQuery, TermWeight, TermScorer}; +use query::boolean_query::BooleanScorer; struct MultiTermWeight { - query: MultiTermQuery, - similitude: TfIdf, + weights: Vec, + occur_filter: OccurFilter, } impl Weight for MultiTermWeight { fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { - - - let mut postings_and_fieldnorms = Vec::with_capacity(self.query.num_terms()); - { - for &(occur, ref term) in &self.query.occur_terms { - if let Some(postings) = reader.read_postings(term, SegmentPostingsOption::Freq) { - let field = term.field(); - let fieldnorm_reader = try!(reader.get_fieldnorms_reader(field)); - postings_and_fieldnorms.push((occur, postings, fieldnorm_reader)); - } + let mut term_scorers: Vec> = Vec::new(); + for term_weight in &self.weights { + let term_scorer_option = try!(term_weight.specialized_scorer(reader)); + if let Some(term_scorer) = term_scorer_option { + term_scorers.push(term_scorer); } } - if postings_and_fieldnorms.len() > 64 { - // TODO putting the SHOULD at the end of the list should push the limit. - return Err(Error::InvalidArgument(String::from("Limit of 64 terms was exceeded."))); - } - Ok(box DAATMultiTermScorer::new(postings_and_fieldnorms, self.similitude.clone())) + Ok(box BooleanScorer::new(term_scorers, self.occur_filter.clone())) } } /// Query involving one or more terms. + #[derive(Eq, Clone, PartialEq, Debug)] -pub struct MultiTermQuery { - occur_terms: Vec<(Occur, Term)>, +pub struct MultiTermQuery { + // TODO need a better Debug + occur_terms: Vec<(Occur, Term)> } impl MultiTermQuery { @@ -55,57 +46,10 @@ impl MultiTermQuery { pub fn num_terms(&self,) -> usize { self.occur_terms.len() } - - /// Builds the similitude object - fn similitude(&self, searcher: &Searcher) -> TfIdf { - let num_terms = self.num_terms(); - let num_docs = searcher.num_docs() as f32; - let idfs: Vec = self.occur_terms - .iter() - .map(|&(_, ref term)| searcher.doc_freq(term)) - .map(|doc_freq| { - if doc_freq == 0 { - 1. - } - else { - 1. + ( num_docs / (doc_freq as f32) ).ln() - } - }) - .collect(); - let query_coords = (0..num_terms + 1) - .map(|i| (i as f32) / (num_terms as f32)) - .collect(); - // TODO have the actual terms in these names - let term_names = self.occur_terms - .iter() - .map(|&(_, ref term)| format!("{:?}", &term)) - .collect(); - let mut tfidf = TfIdf::new(query_coords, idfs); - tfidf.set_term_names(term_names); - tfidf - } + } -impl From> for MultiTermQuery { - fn from(occur_terms: Vec<(Occur, Term)>) -> MultiTermQuery { - MultiTermQuery { - occur_terms: occur_terms, - } - } -} - -impl From> for MultiTermQuery { - fn from(terms: Vec) -> MultiTermQuery { - let should_terms = terms - .into_iter() - .map(|term| (Occur::Should, term)) - .collect(); - MultiTermQuery { - occur_terms: should_terms, - } - } -} impl Query for MultiTermQuery { @@ -114,14 +58,42 @@ impl Query for MultiTermQuery { } fn weight(&self, searcher: &Searcher) -> Result> { - let similitude = self.similitude(searcher); + let term_queries: Vec = self.occur_terms + .iter() + .map(|&(_, ref term)| TermQuery::from(term.clone())) + .collect(); + let occurs: Vec = self.occur_terms + .iter() + .map(|&(occur, _) | occur.clone()) + .collect(); + let occur_filter = OccurFilter::new(&occurs); + let weights = term_queries.iter() + .map(|term_query| term_query.specialized_weight(searcher)) + .collect(); Ok( Box::new(MultiTermWeight { - query: self.clone(), - similitude: similitude + weights: weights, + occur_filter: occur_filter, }) ) } - } + +impl From> for MultiTermQuery { + fn from(occur_terms: Vec<(Occur, Term)>) -> MultiTermQuery { + MultiTermQuery { + occur_terms: occur_terms + } + } +} + +impl From> for MultiTermQuery { + fn from(terms: Vec) -> MultiTermQuery { + let should_terms: Vec<(Occur, Term)> = terms + .into_iter() + .map(|term| (Occur::Should, term)) + .collect(); + MultiTermQuery::from(should_terms) + } +} \ No newline at end of file diff --git a/src/query/occur_filter.rs b/src/query/occur_filter.rs index 914c8f42b..188fc637f 100644 --- a/src/query/occur_filter.rs +++ b/src/query/occur_filter.rs @@ -1,5 +1,6 @@ use query::Occur; +#[derive(Clone)] pub struct OccurFilter { and_mask: u64, result: u64, diff --git a/src/query/query.rs b/src/query/query.rs index 3b15f3010..f4bcfa3c2 100644 --- a/src/query/query.rs +++ b/src/query/query.rs @@ -82,5 +82,4 @@ pub trait Query: fmt::Debug { } Ok(timer_tree) } - } diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs index db8c5ff1d..e8be286c1 100644 --- a/src/query/term_query/mod.rs +++ b/src/query/term_query/mod.rs @@ -2,4 +2,6 @@ mod term_query; mod term_weight; mod term_scorer; -pub use self::term_query::TermQuery; \ No newline at end of file +pub use self::term_query::TermQuery; +pub use self::term_weight::TermWeight; +pub use self::term_scorer::TermScorer; diff --git a/src/query/term_query/term_query.rs b/src/query/term_query/term_query.rs index e763a36ae..6049d8455 100644 --- a/src/query/term_query/term_query.rs +++ b/src/query/term_query/term_query.rs @@ -11,6 +11,16 @@ pub struct TermQuery { term: Term, } +impl TermQuery { + pub fn specialized_weight(&self, searcher: &Searcher) -> TermWeight { + let doc_freq = searcher.doc_freq(&self.term); + TermWeight { + doc_freq: doc_freq, + term: self.term.clone() + } + } +} + impl From for TermQuery { fn from(term: Term) -> TermQuery { TermQuery { @@ -25,10 +35,7 @@ impl Query for TermQuery { } fn weight(&self, searcher: &Searcher) -> Result> { - let doc_freq = searcher.doc_freq(&self.term); - Ok(box TermWeight { - doc_freq: doc_freq, - term: self.term.clone() - }) + Ok(box self.specialized_weight(searcher)) } + } diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs index b00f1ce16..acd5ffe83 100644 --- a/src/query/term_query/term_weight.rs +++ b/src/query/term_query/term_weight.rs @@ -1,5 +1,4 @@ use Term; -use Score; use query::Weight; use core::SegmentReader; use query::Scorer; @@ -17,19 +16,34 @@ pub struct TermWeight { impl Weight for TermWeight { fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { - let field = self.term.field(); - let fieldnorm_reader = try!(reader.get_fieldnorms_reader(field)); - if let Some(segment_postings) = reader.read_postings(&self.term, SegmentPostingsOption::Freq) { - let scorer: TermScorer = TermScorer { - idf: 1f32 / (self.doc_freq as f32), - fieldnorm_reader: fieldnorm_reader, - segment_postings: segment_postings, - }; - Ok(box scorer) - } - else { - Ok(box EmptyScorer) + let specialized_scorer_option = try!(self.specialized_scorer(reader)); + match specialized_scorer_option { + Some(term_scorer) => { + Ok(box term_scorer) + } + None => { + Ok(box EmptyScorer) + } } } +} + +impl TermWeight { + + pub fn specialized_scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result>> { + let field = self.term.field(); + let fieldnorm_reader = try!(reader.get_fieldnorms_reader(field)); + Ok( + reader.read_postings(&self.term, SegmentPostingsOption::Freq) + .map(|segment_postings| + TermScorer { + idf: 1f32 / (self.doc_freq as f32), + fieldnorm_reader: fieldnorm_reader, + segment_postings: segment_postings, + } + ) + ) + } + } \ No newline at end of file From 249759c8785e9ffbccabe12b4373bbe82b142fea Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Tue, 1 Nov 2016 11:42:29 +0900 Subject: [PATCH 10/19] issue/50 Test broken --- src/fastfield/mod.rs | 15 +- src/fastfield/reader.rs | 33 ++- src/postings/mod.rs | 2 + src/postings/segment_postings.rs | 1 - src/postings/segment_postings_tester.rs | 76 +++++++ src/postings/serializer.rs | 27 ++- src/query/boolean_query/boolean_scorer.rs | 138 ++++++++---- src/query/boolean_query/mod.rs | 4 +- src/query/boolean_query/score_combiner.rs | 46 ++++ src/query/daat_multiterm_scorer.rs | 253 ---------------------- src/query/mod.rs | 4 - src/query/term_query/term_scorer.rs | 3 +- 12 files changed, 284 insertions(+), 318 deletions(-) create mode 100644 src/postings/segment_postings_tester.rs create mode 100644 src/query/boolean_query/score_combiner.rs delete mode 100644 src/query/daat_multiterm_scorer.rs diff --git a/src/fastfield/mod.rs b/src/fastfield/mod.rs index d1e1c8b6d..39d68ebfb 100644 --- a/src/fastfield/mod.rs +++ b/src/fastfield/mod.rs @@ -39,9 +39,7 @@ fn compute_num_bits(amplitude: u32) -> u8 { mod tests { use super::compute_num_bits; - use super::U32FastFieldsReader; - use super::U32FastFieldsWriter; - use super::FastFieldSerializer; + use super::*; use schema::Field; use std::path::Path; use directory::{Directory, WritePtr, RAMDirectory}; @@ -81,6 +79,17 @@ mod tests { doc.add_u32(field, value); fast_field_writers.add_document(&doc); } + + #[test] + pub fn test_fastfield() { + let test_fastfield = U32FastFieldReader::from(vec!(100,200,300)); + println!("{}", test_fastfield.get(0)); + println!("{}", test_fastfield.get(1)); + println!("{}", test_fastfield.get(2)); + assert_eq!(test_fastfield.get(0), 100); + assert_eq!(test_fastfield.get(1), 200); + assert_eq!(test_fastfield.get(2), 300); + } #[test] fn test_intfastfield_small() { diff --git a/src/fastfield/reader.rs b/src/fastfield/reader.rs index f57aa9362..18046025a 100644 --- a/src/fastfield/reader.rs +++ b/src/fastfield/reader.rs @@ -5,8 +5,12 @@ use std::ops::Deref; use directory::ReadOnlySource; use common::BinarySerializable; use DocId; -use schema::Field; - +use schema::{Field, SchemaBuilder}; +use std::path::Path; +use schema::FAST; +use directory::{WritePtr, RAMDirectory, Directory}; +use fastfield::FastFieldSerializer; +use fastfield::U32FastFieldsWriter; use super::compute_num_bits; pub struct U32FastFieldReader { @@ -62,6 +66,31 @@ impl U32FastFieldReader { } } + +impl From> for U32FastFieldReader { + fn from(vals: Vec) -> U32FastFieldReader { + let mut schema_builder = SchemaBuilder::default(); + let field = schema_builder.add_u32_field("field", FAST); + let schema = schema_builder.build(); + let path = Path::new("test"); + let mut directory: RAMDirectory = RAMDirectory::create(); + { + let write: WritePtr = directory.open_write(Path::new("test")).unwrap(); + let mut serializer = FastFieldSerializer::new(write).unwrap(); + let mut fast_field_writers = U32FastFieldsWriter::from_schema(&schema); + for val in vals { + let mut fast_field_writer = fast_field_writers.get_field_writer(field).unwrap(); + fast_field_writer.add_val(val); + } + fast_field_writers.serialize(&mut serializer).unwrap(); + serializer.close().unwrap(); + } + let source = directory.open_read(&path).unwrap(); + let fast_field_readers = U32FastFieldsReader::open(source).unwrap(); + fast_field_readers.get_field(field).unwrap() + } +} + pub struct U32FastFieldsReader { source: ReadOnlySource, field_offsets: HashMap, diff --git a/src/postings/mod.rs b/src/postings/mod.rs index 374f08c33..2227353de 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -17,7 +17,9 @@ mod offset_postings; mod freq_handler; mod docset; mod segment_postings_option; +mod segment_postings_tester; +pub use self::segment_postings_tester::SegmentPostingsTester; pub use self::docset::{SkipResult, DocSet}; pub use self::offset_postings::OffsetPostings; pub use self::recorder::{Recorder, NothingRecorder, TermFrequencyRecorder, TFAndPositionRecorder}; diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs index 8de872f90..ab12add7d 100644 --- a/src/postings/segment_postings.rs +++ b/src/postings/segment_postings.rs @@ -62,7 +62,6 @@ impl<'a> SegmentPostings<'a> { impl<'a> DocSet for SegmentPostings<'a> { - // goes to the next element. // next needs to be called a first time to point to the correct element. #[inline] diff --git a/src/postings/segment_postings_tester.rs b/src/postings/segment_postings_tester.rs new file mode 100644 index 000000000..b34e69003 --- /dev/null +++ b/src/postings/segment_postings_tester.rs @@ -0,0 +1,76 @@ +use super::FreqHandler; +use DocId; +use std::path::Path; +use super::SegmentPostings; +use super::serializer::PostingsSerializer; +use schema::{SchemaBuilder, STRING}; +use directory::{RAMDirectory, Directory}; +use schema::Term; + + +const EMPTY_POSITIONS: [DocId; 0] = [0u32; 0]; + +pub struct SegmentPostingsTester { + data: Vec, + len: u32, +} + +impl SegmentPostingsTester { + pub fn get(&self) -> SegmentPostings { + SegmentPostings::from_data(self.len, &self.data, FreqHandler::new_without_freq()) + } +} + +impl From> for SegmentPostingsTester { + + fn from(doc_ids: Vec) -> SegmentPostingsTester { + let mut directory = RAMDirectory::create(); + let mut schema_builder = SchemaBuilder::default(); + let field = schema_builder.add_text_field("text", STRING); + let schema = schema_builder.build(); + let mut postings_serializer = PostingsSerializer::new( + directory.open_write(Path::new("terms")).unwrap(), + directory.open_write(Path::new("postings")).unwrap(), + directory.open_write(Path::new("positions")).unwrap(), + schema + ).unwrap(); + let term = Term::from_field_text(field, "dummy"); + postings_serializer.new_term(&term, doc_ids.len() as u32); + for doc_id in &doc_ids { + postings_serializer.write_doc(*doc_id, 1u32, &EMPTY_POSITIONS); + } + postings_serializer.close_term(); + postings_serializer.close(); + let postings_data = directory.open_read(Path::new("postings")).unwrap(); + SegmentPostingsTester { + data: Vec::from(postings_data.as_slice()), + len: doc_ids.len() as u32, + } + } + +} + + + +#[cfg(test)] +mod tests { + + use super::*; + use postings::DocSet; + + #[test] + pub fn test_segment_postings_tester() { + let segment_postings_tester = SegmentPostingsTester::from(vec!(1,2,17,32)); + let mut postings = segment_postings_tester.get(); + assert!(postings.advance()); + assert_eq!(postings.doc(), 1); + assert!(postings.advance()); + assert_eq!(postings.doc(), 2); + assert!(postings.advance()); + assert_eq!(postings.doc(), 17); + assert!(postings.advance()); + assert_eq!(postings.doc(), 32); + assert!(!postings.advance()); + } + +} diff --git a/src/postings/serializer.rs b/src/postings/serializer.rs index 5fb7c9505..1cf1b5392 100644 --- a/src/postings/serializer.rs +++ b/src/postings/serializer.rs @@ -66,13 +66,16 @@ pub struct PostingsSerializer { impl PostingsSerializer { + + /// Open a new `PostingsSerializer` for the given segment - pub fn open(segment: &mut Segment) -> Result { - let terms_write = try!(segment.open_write(SegmentComponent::TERMS)); + pub fn new( + terms_write: WritePtr, + postings_write: WritePtr, + positions_write: WritePtr, + schema: Schema + ) -> Result { let terms_fst_builder = try!(FstMapBuilder::new(terms_write)); - let postings_write = try!(segment.open_write(SegmentComponent::POSTINGS)); - let positions_write = try!(segment.open_write(SegmentComponent::POSITIONS)); - let schema = segment.schema(); Ok(PostingsSerializer { terms_fst_builder: terms_fst_builder, postings_write: postings_write, @@ -91,6 +94,20 @@ impl PostingsSerializer { }) } + + /// Open a new `PostingsSerializer` for the given segment + pub fn open(segment: &mut Segment) -> Result { + let terms_write = try!(segment.open_write(SegmentComponent::TERMS)); + let postings_write = try!(segment.open_write(SegmentComponent::POSTINGS)); + let positions_write = try!(segment.open_write(SegmentComponent::POSITIONS)); + PostingsSerializer::new( + terms_write, + postings_write, + positions_write, + segment.schema() + ) + } + fn load_indexing_options(&mut self, field: Field) { let field_entry: &FieldEntry = self.schema.get_field_entry(field); self.text_indexing_options = match *field_entry.field_type() { diff --git a/src/query/boolean_query/boolean_scorer.rs b/src/query/boolean_query/boolean_scorer.rs index 559640230..91b31d192 100644 --- a/src/query/boolean_query/boolean_scorer.rs +++ b/src/query/boolean_query/boolean_scorer.rs @@ -5,46 +5,7 @@ use std::collections::BinaryHeap; use std::cmp::Ordering; use postings::DocSet; use query::OccurFilter; - - -struct ScoreCombiner { - coords: Vec, - num_fields: usize, - score: Score, -} - -impl ScoreCombiner { - - fn update(&mut self, score: Score) { - self.score += score; - self.num_fields += 1; - } - - fn clear(&mut self,) { - self.score = 0f32; - self.num_fields = 0; - } - - /// Compute the coord term - fn coord(&self,) -> f32 { - self.coords[self.num_fields] - } - - #[inline] - fn score(&self, ) -> Score { - self.score * self.coord() - } -} - -impl From> for ScoreCombiner { - fn from(coords: Vec) -> ScoreCombiner { - ScoreCombiner { - coords: coords, - num_fields: 0, - score: 0f32, - } - } -} +use query::boolean_query::ScoreCombiner; /// Each `HeapItem` represents the head of @@ -82,12 +43,13 @@ pub struct BooleanScorer { impl BooleanScorer { - pub fn new(postings: Vec, filter: OccurFilter) -> BooleanScorer { - let num_postings = postings.len(); - let query_coords: Vec = (0..num_postings + 1) - .map(|i| (i as Score) / (num_postings as Score)) - .collect(); - let score_combiner = ScoreCombiner::from(query_coords); + pub fn set_score_combiner(&mut self, score_combiner: ScoreCombiner) { + self.score_combiner = score_combiner; + } + + pub fn new(postings: Vec, + filter: OccurFilter) -> BooleanScorer { + let score_combiner = ScoreCombiner::default_for_num_scorers(postings.len()); let mut non_empty_postings: Vec = Vec::new(); for mut posting in postings { let non_empty = posting.advance(); @@ -131,10 +93,9 @@ impl BooleanScorer { let mut mutable_head = self.queue.peek_mut().unwrap(); let cur_postings = &mut self.postings[mutable_head.ord as usize]; if cur_postings.advance() { - mutable_head.doc = cur_postings.doc(); + mutable_head.doc = cur_postings.doc(); return; } - } self.queue.pop(); } @@ -188,3 +149,84 @@ impl Scorer for BooleanScorer { } } + + + +#[cfg(test)] +mod tests { + + use super::*; + use postings::{DocSet, VecPostings}; + use query::TfIdf; + use query::Scorer; + use query::OccurFilter; + use query::term_query::TermScorer; + use directory::Directory; + use directory::RAMDirectory; + use schema::Field; + use super::super::ScoreCombiner; + use std::path::Path; + use query::Occur; + use postings::SegmentPostingsTester; + use postings::Postings; + use fastfield::{U32FastFieldReader, U32FastFieldWriter, FastFieldSerializer}; + + + + fn abs_diff(left: f32, right: f32) -> f32 { + (right - left).abs() + } + + #[test] + pub fn test_boolean_scorer() { + let occurs = vec!(Occur::Should, Occur::Should); + let occur_filter = OccurFilter::new(&occurs); + + let left_fieldnorms = U32FastFieldReader::from(vec!(100,200,300)); + let left_tester = SegmentPostingsTester::from(vec!(1, 2, 3)); + let left = left_tester.get(); + let left_scorer = TermScorer { + idf: 1f32, + fieldnorm_reader: left_fieldnorms, + segment_postings: left, + }; + + let right_fieldnorms = U32FastFieldReader::from(vec!(15,25,35)); + let right_tester = SegmentPostingsTester::from(vec!(1, 3, 8)); + let right = right_tester.get(); + let mut right_scorer = TermScorer { + idf: 4f32, + fieldnorm_reader: right_fieldnorms, + segment_postings: right, + }; + let score_combiner = ScoreCombiner::from(vec!(0f32, 1f32, 2f32)); + let mut boolean_scorer = BooleanScorer::new(vec!(left_scorer, right_scorer), occur_filter); + boolean_scorer.set_score_combiner(score_combiner); + assert_eq!(boolean_scorer.next(), Some(1u32)); + assert!(abs_diff(boolean_scorer.score(), 1.7414213) < 0.001); + assert_eq!(boolean_scorer.next(), Some(2u32)); + assert!(abs_diff(boolean_scorer.score(), 0.057735026) < 0.001f32); + assert_eq!(boolean_scorer.next(), Some(3u32)); + assert_eq!(boolean_scorer.next(), Some(8u32)); + assert!(abs_diff(boolean_scorer.score(), 1.0327955) < 0.001f32); + assert!(!boolean_scorer.advance()); + } + + + #[test] + pub fn test_term_scorer() { + let left_fieldnorms = U32FastFieldReader::from(vec!(10, 4)); + assert_eq!(left_fieldnorms.get(0), 10); + assert_eq!(left_fieldnorms.get(1), 4); + let left_tester = SegmentPostingsTester::from(vec!(1)); + let left = left_tester.get(); + let mut left_scorer = TermScorer { + idf: 0.30685282, // 1f32, + fieldnorm_reader: left_fieldnorms, + segment_postings: left, + }; + left_scorer.advance(); + assert!(abs_diff(left_scorer.score(), 0.15342641) < 0.001f32); + } + +} diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 9e7506c2e..3f19cb92e 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -2,7 +2,9 @@ mod boolean_clause; mod boolean_query; mod boolean_scorer; mod boolean_weight; +mod score_combiner; pub use self::boolean_query::BooleanQuery; pub use self::boolean_clause::BooleanClause; -pub use self::boolean_scorer::BooleanScorer; \ No newline at end of file +pub use self::boolean_scorer::BooleanScorer; +pub use self::score_combiner::ScoreCombiner; \ No newline at end of file diff --git a/src/query/boolean_query/score_combiner.rs b/src/query/boolean_query/score_combiner.rs new file mode 100644 index 000000000..204c57c23 --- /dev/null +++ b/src/query/boolean_query/score_combiner.rs @@ -0,0 +1,46 @@ +use Score; + +pub struct ScoreCombiner { + coords: Vec, + num_fields: usize, + score: Score, +} + +impl ScoreCombiner { + + pub fn update(&mut self, score: Score) { + self.score += score; + self.num_fields += 1; + } + + pub fn clear(&mut self,) { + self.score = 0f32; + self.num_fields = 0; + } + + /// Compute the coord term + fn coord(&self,) -> f32 { + self.coords[self.num_fields] + } + + pub fn score(&self, ) -> Score { + self.score * self.coord() + } + + pub fn default_for_num_scorers(num_scorers: usize) -> ScoreCombiner { + let query_coords: Vec = (0..num_scorers + 1) + .map(|i| (i as Score) / (num_scorers as Score)) + .collect(); + ScoreCombiner::from(query_coords) + } +} + +impl From> for ScoreCombiner { + fn from(coords: Vec) -> ScoreCombiner { + ScoreCombiner { + coords: coords, + num_fields: 0, + score: 0f32, + } + } +} \ No newline at end of file diff --git a/src/query/daat_multiterm_scorer.rs b/src/query/daat_multiterm_scorer.rs deleted file mode 100644 index 98f51f49f..000000000 --- a/src/query/daat_multiterm_scorer.rs +++ /dev/null @@ -1,253 +0,0 @@ -use DocId; -use postings::{Postings, DocSet}; -use std::cmp::Ordering; -use std::collections::BinaryHeap; -use query::MultiTermAccumulator; -use query::Similarity; -use fastfield::U32FastFieldReader; -use query::Occur; -use std::iter; -use query::OccurFilter; -use super::Scorer; -use Score; - -/// Each `HeapItem` represents the head of -/// a segment postings being merged. -/// -/// * `doc` - is the current doc id for the given segment postings -/// * `ord` - is the ordinal used to identify to which segment postings -/// this heap item belong to. -#[derive(Eq, PartialEq)] -struct HeapItem { - doc: DocId, - ord: u32, -} - -/// `HeapItem` are ordered by the document -impl PartialOrd for HeapItem { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for HeapItem { - fn cmp(&self, other:&Self) -> Ordering { - (other.doc).cmp(&self.doc) - } -} - - -/// Document-At-A-Time multi term scorer. -/// -/// The scorer merges multiple segment postings and pushes -/// term information to the score accumulator. -pub struct DAATMultiTermScorer { - fieldnorm_readers: Vec, - postings: Vec, - term_frequencies: Vec, - queue: BinaryHeap, - doc: DocId, - similarity: TAccumulator, - filter: OccurFilter, -} - -impl DAATMultiTermScorer { - - fn new_non_empty( - - fieldnorm_readers: Vec, - postings: Vec, - similarity: TAccumulator, - filter: OccurFilter - ) -> DAATMultiTermScorer { - let mut term_frequencies: Vec = iter::repeat(0u32).take(postings.len()).collect(); - let heap_items: Vec = postings - .iter() - .map(|posting| { - (posting.doc(), posting.term_freq()) - }) - .enumerate() - .map(|(ord, (doc, tf))| { - term_frequencies[ord] = tf; - HeapItem { - doc: doc, - ord: ord as u32 - } - }) - .collect(); - DAATMultiTermScorer { - fieldnorm_readers: fieldnorm_readers, - postings: postings, - term_frequencies: term_frequencies, - queue: BinaryHeap::from(heap_items), - doc: 0, - similarity: similarity, - filter: filter - } - } - - /// Constructor - pub fn new(postings_and_fieldnorms: Vec<(Occur, TPostings, U32FastFieldReader)>, similarity: TAccumulator) -> DAATMultiTermScorer { - let mut postings = Vec::new(); - let mut fieldnorm_readers = Vec::new(); - let mut occurs = Vec::new(); - for (occur, mut posting, fieldnorm_reader) in postings_and_fieldnorms { - if posting.advance() { - postings.push(posting); - fieldnorm_readers.push(fieldnorm_reader); - occurs.push(occur); - } - } - let filter = OccurFilter::new(&occurs); - DAATMultiTermScorer::new_non_empty(fieldnorm_readers, postings, similarity, filter) - } - - /// Returns the scorer - pub fn scorer(&self,) -> &TAccumulator { - &self.similarity - } - - /// Advances the head of our heap (the segment postings with the lowest doc) - /// It will also update the new current `DocId` as well as the term frequency - /// associated with the segment postings. - /// - /// After advancing the `SegmentPosting`, the postings is removed from the heap - /// if it has been entirely consumed, or pushed back into the heap. - /// - /// # Panics - /// This method will panic if the head `SegmentPostings` is not empty. - fn advance_head(&mut self,) { - { - let mut mutable_head = self.queue.peek_mut().unwrap(); - let cur_postings = &mut self.postings[mutable_head.ord as usize]; - if cur_postings.advance() { - let doc = cur_postings.doc(); - self.term_frequencies[mutable_head.ord as usize] = cur_postings.term_freq(); - mutable_head.doc = doc; - return; - } - - } - self.queue.pop(); - } - - /// Returns the field norm for the segment postings with the given ordinal, - /// and the given document. - fn get_field_norm(&self, ord:usize, doc:DocId) -> u32 { - self.fieldnorm_readers[ord].get(doc) - } - -} - -impl Scorer for DAATMultiTermScorer { - fn score(&self,) -> Score { - self.similarity.score() - } -} - -impl DocSet for DAATMultiTermScorer { - - fn advance(&mut self,) -> bool { - loop { - self.similarity.clear(); - let mut ord_bitset = 0u64; - match self.queue.peek() { - Some(heap_item) => { - self.doc = heap_item.doc; - let ord: usize = heap_item.ord as usize; - let fieldnorm = self.get_field_norm(ord, heap_item.doc); - let tf = self.term_frequencies[ord]; - self.similarity.update(ord, tf, fieldnorm); - ord_bitset |= 1 << ord; - } - None => { - return false; - } - } - self.advance_head(); - while let Some(&HeapItem {doc, ord}) = self.queue.peek() { - if doc == self.doc { - let peek_ord: usize = ord as usize; - let peek_tf = self.term_frequencies[peek_ord]; - let peek_fieldnorm = self.get_field_norm(peek_ord, doc); - self.similarity.update(peek_ord, peek_tf, peek_fieldnorm); - ord_bitset |= 1 << peek_ord; - } - else { - break; - } - self.advance_head(); - } - if self.filter.accept(ord_bitset) { - return true; - } - } - } - - fn doc(&self,) -> DocId { - self.doc - } -} - -#[cfg(test)] -mod tests { - - use super::*; - use postings::{DocSet, VecPostings}; - use query::TfIdf; - use query::Scorer; - use directory::Directory; - use directory::RAMDirectory; - use schema::Field; - use std::path::Path; - use query::Occur; - use fastfield::{U32FastFieldReader, U32FastFieldWriter, FastFieldSerializer}; - - - pub fn create_u32_fastfieldreader(field: Field, vals: Vec) -> U32FastFieldReader { - let mut u32_field_writer = U32FastFieldWriter::new(field); - for val in vals { - u32_field_writer.add_val(val); - } - let path = Path::new("some_path"); - let mut directory = RAMDirectory::create(); - { - let write = directory.open_write(&path).unwrap(); - let mut serializer = FastFieldSerializer::new(write).unwrap(); - u32_field_writer.serialize(&mut serializer).unwrap(); - serializer.close().unwrap(); - } - let read = directory.open_read(&path).unwrap(); - U32FastFieldReader::open(read).unwrap() - } - - fn abs_diff(left: f32, right: f32) -> f32 { - (right - left).abs() - } - - #[test] - pub fn test_daat_scorer() { - let left_fieldnorms = create_u32_fastfieldreader(Field(1), vec!(100,200,300)); - let right_fieldnorms = create_u32_fastfieldreader(Field(2), vec!(15,25,35)); - let left = VecPostings::from(vec!(1, 2, 3)); - let right = VecPostings::from(vec!(1, 3, 8)); - let tfidf = TfIdf::new(vec!(0f32, 1f32, 2f32), vec!(1f32, 4f32)); - let mut daat_scorer = DAATMultiTermScorer::new( - vec!( - (Occur::Should, left, left_fieldnorms), - (Occur::Should, right, right_fieldnorms), - ), - tfidf - ); - assert_eq!(daat_scorer.next(), Some(1u32)); - assert!(abs_diff(daat_scorer.score(), 2.182179f32) < 0.001); - assert_eq!(daat_scorer.next(), Some(2u32)); - assert!(abs_diff(daat_scorer.score(), 0.2236068) < 0.001f32); - assert_eq!(daat_scorer.next(), Some(3u32)); - assert_eq!(daat_scorer.next(), Some(8u32)); - assert!(abs_diff(daat_scorer.score(), 0.8944272f32) < 0.001f32); - assert!(!daat_scorer.advance()); - } - -} - diff --git a/src/query/mod.rs b/src/query/mod.rs index 161f30425..57bdeb1ba 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -13,7 +13,6 @@ mod query_parser; mod explanation; mod tfidf; mod occur; -mod daat_multiterm_scorer; mod similarity; mod weight; mod occur_filter; @@ -26,9 +25,6 @@ pub use self::empty_scorer::EmptyScorer; pub use self::occur_filter::OccurFilter; pub use self::similarity::Similarity; - -pub use self::daat_multiterm_scorer::DAATMultiTermScorer; - pub use self::boolean_query::BooleanQuery; pub use self::occur::Occur; pub use self::query::Query; diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index 6ad5c591d..e7c3bf644 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -27,6 +27,7 @@ impl<'a> Scorer for TermScorer<'a> { fn score(&self,) -> Score { let doc = self.segment_postings.doc(); let field_norm = self.fieldnorm_reader.get(doc); - self.idf * (self.segment_postings.term_freq() as f32 / field_norm as f32).sqrt() + self.idf * (self.segment_postings.term_freq() as f32 / field_norm as f32).sqrt() } } + From c2c65d311d2e05570defe3860e82b28605680133 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Wed, 2 Nov 2016 10:41:07 +0900 Subject: [PATCH 11/19] issue/50 Added SegmentPostingsTestFactory --- src/postings/mod.rs | 7 +- src/postings/segment_postings.rs | 4 +- src/postings/segment_postings_test_factory.rs | 87 +++++++++++ src/postings/segment_postings_tester.rs | 76 ---------- src/postings/serializer.rs | 4 +- src/query/boolean_query/boolean_scorer.rs | 20 +-- src/query/mod.rs | 3 - src/query/tfidf.rs | 143 ------------------ 8 files changed, 105 insertions(+), 239 deletions(-) create mode 100644 src/postings/segment_postings_test_factory.rs delete mode 100644 src/postings/segment_postings_tester.rs delete mode 100644 src/query/tfidf.rs diff --git a/src/postings/mod.rs b/src/postings/mod.rs index 2227353de..e66288552 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -17,9 +17,8 @@ mod offset_postings; mod freq_handler; mod docset; mod segment_postings_option; -mod segment_postings_tester; +mod segment_postings_test_factory; -pub use self::segment_postings_tester::SegmentPostingsTester; pub use self::docset::{SkipResult, DocSet}; pub use self::offset_postings::OffsetPostings; pub use self::recorder::{Recorder, NothingRecorder, TermFrequencyRecorder, TFAndPositionRecorder}; @@ -37,6 +36,10 @@ pub use self::segment_postings::SegmentPostings; pub use self::intersection::intersection; pub use self::intersection::IntersectionDocSet; pub use self::freq_handler::FreqHandler; + +#[cfg(test)] +pub use self::segment_postings_test_factory::SegmentPostingsTestFactory; + pub use self::segment_postings_option::SegmentPostingsOption; pub use common::HasLen; diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs index ab12add7d..19b2e2ef6 100644 --- a/src/postings/segment_postings.rs +++ b/src/postings/segment_postings.rs @@ -2,7 +2,7 @@ use compression::{NUM_DOCS_PER_BLOCK, SIMDBlockDecoder}; use DocId; use postings::{Postings, FreqHandler, DocSet, HasLen}; use std::num::Wrapping; - +use std::borrow::Cow; @@ -16,7 +16,7 @@ pub struct SegmentPostings<'a> { doc_offset: u32, block_decoder: SIMDBlockDecoder, freq_handler: FreqHandler, - remaining_data: &'a [u8], + remaining_data: &'a[u8], cur: Wrapping, } diff --git a/src/postings/segment_postings_test_factory.rs b/src/postings/segment_postings_test_factory.rs new file mode 100644 index 000000000..e1aefa480 --- /dev/null +++ b/src/postings/segment_postings_test_factory.rs @@ -0,0 +1,87 @@ +use super::FreqHandler; +use DocId; +use std::mem; +use std::path::{Path, PathBuf}; +use super::SegmentPostings; +use super::serializer::PostingsSerializer; +use schema::{SchemaBuilder, STRING}; +use directory::{RAMDirectory, Directory}; +use schema::Term; + + +const EMPTY_POSITIONS: [DocId; 0] = [0u32; 0]; + +pub struct SegmentPostingsTestFactory { + directory: RAMDirectory, + i: usize, +} + +impl Default for SegmentPostingsTestFactory { + fn default() -> SegmentPostingsTestFactory { + SegmentPostingsTestFactory { + directory: RAMDirectory::create(), + i: 0 + } + } +} + + +//data: Vec, +//len: u32, + +impl SegmentPostingsTestFactory { + pub fn from_data<'a>(&'a self, doc_ids: Vec) -> SegmentPostings<'a> { + let mut schema_builder = SchemaBuilder::default(); + let field = schema_builder.add_text_field("text", STRING); + let schema = schema_builder.build(); + + let postings_path = PathBuf::from(format!("postings{}", self.i)); + let terms_path = PathBuf::from(format!("terms{}", self.i)); + let positions_path = PathBuf::from(format!("positions{}", self.i)); + self.i += 1; + + let mut directory = self.directory.clone(); + let mut postings_serializer = PostingsSerializer::new( + directory.open_write(&terms_path).unwrap(), + directory.open_write(&postings_path).unwrap(), + directory.open_write(&positions_path).unwrap(), + schema + ).unwrap(); + let term = Term::from_field_text(field, "dummy"); + postings_serializer.new_term(&term, doc_ids.len() as u32); + for doc_id in &doc_ids { + postings_serializer.write_doc(*doc_id, 1u32, &EMPTY_POSITIONS); + } + postings_serializer.close_term(); + postings_serializer.close(); + let postings_data = self.directory.open_read(&postings_path).unwrap(); + let ref_postings_data = unsafe { + mem::transmute::<&[u8], &'a [u8]>(postings_data.as_slice()) + }; + SegmentPostings::from_data(doc_ids.len() as u32, ref_postings_data, FreqHandler::new_without_freq()) + } +} + + +#[cfg(test)] +mod tests { + + use super::*; + use postings::DocSet; + + #[test] + pub fn test_segment_postings_tester() { + let segment_postings_tester = SegmentPostingsTestFactory::default(); + let mut postings = segment_postings_tester.from_data(vec!(1,2,17,32)); + assert!(postings.advance()); + assert_eq!(postings.doc(), 1); + assert!(postings.advance()); + assert_eq!(postings.doc(), 2); + assert!(postings.advance()); + assert_eq!(postings.doc(), 17); + assert!(postings.advance()); + assert_eq!(postings.doc(), 32); + assert!(!postings.advance()); + } + +} diff --git a/src/postings/segment_postings_tester.rs b/src/postings/segment_postings_tester.rs deleted file mode 100644 index b34e69003..000000000 --- a/src/postings/segment_postings_tester.rs +++ /dev/null @@ -1,76 +0,0 @@ -use super::FreqHandler; -use DocId; -use std::path::Path; -use super::SegmentPostings; -use super::serializer::PostingsSerializer; -use schema::{SchemaBuilder, STRING}; -use directory::{RAMDirectory, Directory}; -use schema::Term; - - -const EMPTY_POSITIONS: [DocId; 0] = [0u32; 0]; - -pub struct SegmentPostingsTester { - data: Vec, - len: u32, -} - -impl SegmentPostingsTester { - pub fn get(&self) -> SegmentPostings { - SegmentPostings::from_data(self.len, &self.data, FreqHandler::new_without_freq()) - } -} - -impl From> for SegmentPostingsTester { - - fn from(doc_ids: Vec) -> SegmentPostingsTester { - let mut directory = RAMDirectory::create(); - let mut schema_builder = SchemaBuilder::default(); - let field = schema_builder.add_text_field("text", STRING); - let schema = schema_builder.build(); - let mut postings_serializer = PostingsSerializer::new( - directory.open_write(Path::new("terms")).unwrap(), - directory.open_write(Path::new("postings")).unwrap(), - directory.open_write(Path::new("positions")).unwrap(), - schema - ).unwrap(); - let term = Term::from_field_text(field, "dummy"); - postings_serializer.new_term(&term, doc_ids.len() as u32); - for doc_id in &doc_ids { - postings_serializer.write_doc(*doc_id, 1u32, &EMPTY_POSITIONS); - } - postings_serializer.close_term(); - postings_serializer.close(); - let postings_data = directory.open_read(Path::new("postings")).unwrap(); - SegmentPostingsTester { - data: Vec::from(postings_data.as_slice()), - len: doc_ids.len() as u32, - } - } - -} - - - -#[cfg(test)] -mod tests { - - use super::*; - use postings::DocSet; - - #[test] - pub fn test_segment_postings_tester() { - let segment_postings_tester = SegmentPostingsTester::from(vec!(1,2,17,32)); - let mut postings = segment_postings_tester.get(); - assert!(postings.advance()); - assert_eq!(postings.doc(), 1); - assert!(postings.advance()); - assert_eq!(postings.doc(), 2); - assert!(postings.advance()); - assert_eq!(postings.doc(), 17); - assert!(postings.advance()); - assert_eq!(postings.doc(), 32); - assert!(!postings.advance()); - } - -} diff --git a/src/postings/serializer.rs b/src/postings/serializer.rs index 1cf1b5392..673a9059d 100644 --- a/src/postings/serializer.rs +++ b/src/postings/serializer.rs @@ -65,9 +65,7 @@ pub struct PostingsSerializer { } impl PostingsSerializer { - - - + /// Open a new `PostingsSerializer` for the given segment pub fn new( terms_write: WritePtr, diff --git a/src/query/boolean_query/boolean_scorer.rs b/src/query/boolean_query/boolean_scorer.rs index 91b31d192..062487691 100644 --- a/src/query/boolean_query/boolean_scorer.rs +++ b/src/query/boolean_query/boolean_scorer.rs @@ -157,7 +157,6 @@ mod tests { use super::*; use postings::{DocSet, VecPostings}; - use query::TfIdf; use query::Scorer; use query::OccurFilter; use query::term_query::TermScorer; @@ -167,7 +166,7 @@ mod tests { use super::super::ScoreCombiner; use std::path::Path; use query::Occur; - use postings::SegmentPostingsTester; + use postings::SegmentPostingsTestFactory; use postings::Postings; use fastfield::{U32FastFieldReader, U32FastFieldWriter, FastFieldSerializer}; @@ -176,15 +175,18 @@ mod tests { fn abs_diff(left: f32, right: f32) -> f32 { (right - left).abs() } - + + lazy_static! { + static ref segment_postings_test_factory: SegmentPostingsTestFactory = SegmentPostingsTestFactory::default(); + } + #[test] pub fn test_boolean_scorer() { let occurs = vec!(Occur::Should, Occur::Should); let occur_filter = OccurFilter::new(&occurs); - + let left_fieldnorms = U32FastFieldReader::from(vec!(100,200,300)); - let left_tester = SegmentPostingsTester::from(vec!(1, 2, 3)); - let left = left_tester.get(); + let left = segment_postings_test_factory.from_data(vec!(1, 2, 3)); let left_scorer = TermScorer { idf: 1f32, fieldnorm_reader: left_fieldnorms, @@ -192,8 +194,7 @@ mod tests { }; let right_fieldnorms = U32FastFieldReader::from(vec!(15,25,35)); - let right_tester = SegmentPostingsTester::from(vec!(1, 3, 8)); - let right = right_tester.get(); + let right = segment_postings_test_factory.from_data(vec!(1, 3, 8)); let mut right_scorer = TermScorer { idf: 4f32, fieldnorm_reader: right_fieldnorms, @@ -218,8 +219,7 @@ mod tests { let left_fieldnorms = U32FastFieldReader::from(vec!(10, 4)); assert_eq!(left_fieldnorms.get(0), 10); assert_eq!(left_fieldnorms.get(1), 4); - let left_tester = SegmentPostingsTester::from(vec!(1)); - let left = left_tester.get(); + let left = segment_postings_test_factory.from_data(vec!(1)); let mut left_scorer = TermScorer { idf: 0.30685282, // 1f32, fieldnorm_reader: left_fieldnorms, diff --git a/src/query/mod.rs b/src/query/mod.rs index 57bdeb1ba..788aa6760 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -11,7 +11,6 @@ mod similarity_explainer; mod scorer; mod query_parser; mod explanation; -mod tfidf; mod occur; mod similarity; mod weight; @@ -31,8 +30,6 @@ pub use self::query::Query; pub use self::term_query::TermQuery; pub use self::multi_term_query::MultiTermQuery; pub use self::similarity_explainer::SimilarityExplainer; -pub use self::tfidf::TfIdf; - pub use self::scorer::Scorer; pub use self::query_parser::QueryParser; pub use self::explanation::Explanation; diff --git a/src/query/tfidf.rs b/src/query/tfidf.rs deleted file mode 100644 index 749c4663e..000000000 --- a/src/query/tfidf.rs +++ /dev/null @@ -1,143 +0,0 @@ -use Score; -use super::MultiTermAccumulator; -use super::Explanation; -use super::Similarity; - - -/// `TfIdf` is the default pertinence score in tantivy. -/// -/// See [Tf-Idf in the global documentation](https://fulmicoton.gitbooks.io/tantivy-doc/content/tfidf.html) -#[derive(Clone)] -pub struct TfIdf { - coords: Vec, - idf: Vec, - score: f32, - num_fields: usize, - term_names: Option>, //< only here for explain -} - -impl MultiTermAccumulator for TfIdf { - - #[inline] - fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) { - assert!(term_freq != 0u32); - self.score += self.term_score(term_ord, term_freq, fieldnorm); - self.num_fields += 1; - } - - #[inline] - fn clear(&mut self,) { - self.score = 0f32; - self.num_fields = 0; - } -} - -impl TfIdf { - /// Constructor - /// * coords - Coords act as a boosting factor for queries - /// containing many terms. The coords must have a length - /// of `num_terms + 1` - /// * idf - idf value for each given term. `idf` must - /// have a length of `num_terms`. - pub fn new(coords: Vec, idf: Vec) -> TfIdf { - TfIdf { - coords: coords, - idf: idf, - score: 0f32, - num_fields: 0, - term_names: None, - } - } - - /// Compute the coord term - fn coord(&self,) -> f32 { - self.coords[self.num_fields] - } - - /// Set the term names for the explain function - pub fn set_term_names(&mut self, term_names: Vec) { - self.term_names = Some(term_names); - } - - /// Return the name for the ordinal `ord` - fn term_name(&self, ord: usize) -> String { - match self.term_names { - Some(ref term_names_vec) => term_names_vec[ord].clone(), - None => format!("Field({})", ord) - } - } - - #[inline] - fn term_score(&self, term_ord: usize, term_freq: u32, field_norm: u32) -> f32 { - (term_freq as f32 / field_norm as f32).sqrt() * self.idf[term_ord] - } -} - -impl Similarity for TfIdf { - - #[inline] - fn score(&self, ) -> Score { - self.score * self.coord() - } - - fn explain(&self, vals: &[(usize, u32, u32)]) -> Explanation { - let score = self.score(); - let mut explanation = Explanation::with_val(score); - let formula_components: Vec = vals.iter() - .map(|&(ord, _, _)| ord) - .map(|ord| format!("", self.term_name(ord))) - .collect(); - let formula = format!(" * ({})", formula_components.join(" + ")); - explanation.set_formula(&formula); - for &(ord, term_freq, field_norm) in vals { - let term_score = self.term_score(ord, term_freq, field_norm); - let term_explanation = explanation.add_child(&self.term_name(ord), term_score); - term_explanation.set_formula(" sqrt( / ) * "); - } - explanation - } -} - - - - -#[cfg(test)] -mod tests { - - use super::*; - use query::MultiTermAccumulator; - use query::Similarity; - - fn abs_diff(left: f32, right: f32) -> f32 { - (right - left).abs() - } - - #[test] - pub fn test_tfidf() { - let mut tfidf = TfIdf::new(vec!(0f32, 1f32, 2f32), vec!(1f32, 4f32)); - { - tfidf.update(0, 1, 1); - assert!(abs_diff(tfidf.score(), 1f32) < 0.001f32); - tfidf.clear(); - } - { - tfidf.update(1, 1, 1); - assert_eq!(tfidf.score(), 4f32); - tfidf.clear(); - } - { - tfidf.update(0, 2, 1); - assert!(abs_diff(tfidf.score(), 1.4142135) < 0.001f32); - tfidf.clear(); - } - { - tfidf.update(0, 1, 1); - tfidf.update(1, 1, 1); - assert_eq!(tfidf.score(), 10f32); - tfidf.clear(); - } - - - } - -} \ No newline at end of file From 6229a927308499e9af2a5ca96ca896cf327538d1 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Wed, 2 Nov 2016 12:54:06 +0900 Subject: [PATCH 12/19] issue/50 Removed SegmentPostingsTestFactory for just using VecPostings --- src/common/mod.rs | 1 + src/fastfield/reader.rs | 12 +++ src/postings/docset.rs | 2 +- src/postings/mod.rs | 4 - src/postings/postings.rs | 17 ---- src/postings/segment_postings.rs | 14 ++- src/postings/segment_postings_test_factory.rs | 87 ------------------- src/query/boolean_query/boolean_query.rs | 11 ++- src/query/boolean_query/boolean_scorer.rs | 51 ++++------- src/query/boolean_query/boolean_weight.rs | 17 ++-- src/query/empty_scorer.rs | 22 ----- src/query/mod.rs | 7 +- src/query/multi_term_query.rs | 9 +- src/query/occur.rs | 6 +- src/query/occur_filter.rs | 11 ++- src/query/scorer.rs | 20 ++++- src/query/term_query/term_query.rs | 19 +++- src/query/term_query/term_scorer.rs | 9 +- src/query/term_query/term_weight.rs | 44 ++++++---- src/query/weight.rs | 9 +- 20 files changed, 153 insertions(+), 219 deletions(-) delete mode 100644 src/postings/segment_postings_test_factory.rs delete mode 100644 src/query/empty_scorer.rs diff --git a/src/common/mod.rs b/src/common/mod.rs index 77ff67edc..d14e17617 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -25,3 +25,4 @@ pub trait HasLen { self.len() == 0 } } + diff --git a/src/fastfield/reader.rs b/src/fastfield/reader.rs index 18046025a..335799565 100644 --- a/src/fastfield/reader.rs +++ b/src/fastfield/reader.rs @@ -13,6 +13,14 @@ use fastfield::FastFieldSerializer; use fastfield::U32FastFieldsWriter; use super::compute_num_bits; + +lazy_static! { + static ref U32_FAST_FIELD_EMPTY: ReadOnlySource = { + let u32_fast_field = U32FastFieldReader::from(Vec::new()); + u32_fast_field._data.clone() + }; +} + pub struct U32FastFieldReader { _data: ReadOnlySource, data_ptr: *const u8, @@ -24,6 +32,10 @@ pub struct U32FastFieldReader { impl U32FastFieldReader { + pub fn empty() -> U32FastFieldReader { + U32FastFieldReader::open(U32_FAST_FIELD_EMPTY.clone()).expect("should always work.") + } + pub fn min_val(&self,) -> u32 { self.min_val } diff --git a/src/postings/docset.rs b/src/postings/docset.rs index db40db619..faf839325 100644 --- a/src/postings/docset.rs +++ b/src/postings/docset.rs @@ -61,7 +61,7 @@ pub trait DocSet { } -impl DocSet for Box { +impl DocSet for Box { fn advance(&mut self,) -> bool { let unboxed: &mut TDocSet = self.borrow_mut(); diff --git a/src/postings/mod.rs b/src/postings/mod.rs index e66288552..0da67d2e5 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -17,7 +17,6 @@ mod offset_postings; mod freq_handler; mod docset; mod segment_postings_option; -mod segment_postings_test_factory; pub use self::docset::{SkipResult, DocSet}; pub use self::offset_postings::OffsetPostings; @@ -37,9 +36,6 @@ pub use self::intersection::intersection; pub use self::intersection::IntersectionDocSet; pub use self::freq_handler::FreqHandler; -#[cfg(test)] -pub use self::segment_postings_test_factory::SegmentPostingsTestFactory; - pub use self::segment_postings_option::SegmentPostingsOption; pub use common::HasLen; diff --git a/src/postings/postings.rs b/src/postings/postings.rs index 8b964d0a9..ff8038750 100644 --- a/src/postings/postings.rs +++ b/src/postings/postings.rs @@ -1,8 +1,5 @@ use std::borrow::Borrow; use postings::docset::DocSet; -use common::HasLen; - - /// Postings (also called inverted list) /// @@ -51,17 +48,3 @@ impl<'a, TPostings: Postings> Postings for &'a mut TPostings { } - -impl HasLen for Box { - fn len(&self,) -> usize { - let unboxed: &THasLen = self.borrow(); - unboxed.borrow().len() - } -} - -impl<'a> HasLen for &'a HasLen { - fn len(&self,) -> usize { - let unref: &HasLen = *self; - unref.len() - } -} diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs index 19b2e2ef6..cac9b86c8 100644 --- a/src/postings/segment_postings.rs +++ b/src/postings/segment_postings.rs @@ -2,9 +2,9 @@ use compression::{NUM_DOCS_PER_BLOCK, SIMDBlockDecoder}; use DocId; use postings::{Postings, FreqHandler, DocSet, HasLen}; use std::num::Wrapping; -use std::borrow::Cow; +const EMPTY_DATA: [u8; 0] = [0u8; 0]; /// `SegmentPostings` represents the inverted list or postings associated to /// a term in a `Segment`. @@ -52,6 +52,18 @@ impl<'a> SegmentPostings<'a> { } } + /// Returns an empty segment postings object + pub fn empty() -> SegmentPostings<'static> { + SegmentPostings { + len: 0, + doc_offset: 0, + block_decoder: SIMDBlockDecoder::new(), + freq_handler: FreqHandler::new_without_freq(), + remaining_data: &EMPTY_DATA, + cur: Wrapping(usize::max_value()), + } + } + /// Index within a block is used as an address when /// interacting with the `FreqHandler` fn index_within_block(&self,) -> usize { diff --git a/src/postings/segment_postings_test_factory.rs b/src/postings/segment_postings_test_factory.rs deleted file mode 100644 index e1aefa480..000000000 --- a/src/postings/segment_postings_test_factory.rs +++ /dev/null @@ -1,87 +0,0 @@ -use super::FreqHandler; -use DocId; -use std::mem; -use std::path::{Path, PathBuf}; -use super::SegmentPostings; -use super::serializer::PostingsSerializer; -use schema::{SchemaBuilder, STRING}; -use directory::{RAMDirectory, Directory}; -use schema::Term; - - -const EMPTY_POSITIONS: [DocId; 0] = [0u32; 0]; - -pub struct SegmentPostingsTestFactory { - directory: RAMDirectory, - i: usize, -} - -impl Default for SegmentPostingsTestFactory { - fn default() -> SegmentPostingsTestFactory { - SegmentPostingsTestFactory { - directory: RAMDirectory::create(), - i: 0 - } - } -} - - -//data: Vec, -//len: u32, - -impl SegmentPostingsTestFactory { - pub fn from_data<'a>(&'a self, doc_ids: Vec) -> SegmentPostings<'a> { - let mut schema_builder = SchemaBuilder::default(); - let field = schema_builder.add_text_field("text", STRING); - let schema = schema_builder.build(); - - let postings_path = PathBuf::from(format!("postings{}", self.i)); - let terms_path = PathBuf::from(format!("terms{}", self.i)); - let positions_path = PathBuf::from(format!("positions{}", self.i)); - self.i += 1; - - let mut directory = self.directory.clone(); - let mut postings_serializer = PostingsSerializer::new( - directory.open_write(&terms_path).unwrap(), - directory.open_write(&postings_path).unwrap(), - directory.open_write(&positions_path).unwrap(), - schema - ).unwrap(); - let term = Term::from_field_text(field, "dummy"); - postings_serializer.new_term(&term, doc_ids.len() as u32); - for doc_id in &doc_ids { - postings_serializer.write_doc(*doc_id, 1u32, &EMPTY_POSITIONS); - } - postings_serializer.close_term(); - postings_serializer.close(); - let postings_data = self.directory.open_read(&postings_path).unwrap(); - let ref_postings_data = unsafe { - mem::transmute::<&[u8], &'a [u8]>(postings_data.as_slice()) - }; - SegmentPostings::from_data(doc_ids.len() as u32, ref_postings_data, FreqHandler::new_without_freq()) - } -} - - -#[cfg(test)] -mod tests { - - use super::*; - use postings::DocSet; - - #[test] - pub fn test_segment_postings_tester() { - let segment_postings_tester = SegmentPostingsTestFactory::default(); - let mut postings = segment_postings_tester.from_data(vec!(1,2,17,32)); - assert!(postings.advance()); - assert_eq!(postings.doc(), 1); - assert!(postings.advance()); - assert_eq!(postings.doc(), 2); - assert!(postings.advance()); - assert_eq!(postings.doc(), 17); - assert!(postings.advance()); - assert_eq!(postings.doc(), 32); - assert!(!postings.advance()); - } - -} diff --git a/src/query/boolean_query/boolean_query.rs b/src/query/boolean_query/boolean_query.rs index 2e9c24bcb..abecdb148 100644 --- a/src/query/boolean_query/boolean_query.rs +++ b/src/query/boolean_query/boolean_query.rs @@ -8,7 +8,16 @@ use query::Query; use query::Occur; use query::OccurFilter; - +/// The boolean query combines a set of queries +/// +/// The documents matched by the boolean query are +/// those which +/// * match all of the sub queries associated with the +/// `Must` occurence +/// * match none of the sub queries associated with the +/// `MustNot` occurence. +/// * match at least one of the subqueries that is not +/// a `MustNot` occurence. #[derive(Debug)] pub struct BooleanQuery { clauses: Vec, diff --git a/src/query/boolean_query/boolean_scorer.rs b/src/query/boolean_query/boolean_scorer.rs index 062487691..7f7d2d24f 100644 --- a/src/query/boolean_query/boolean_scorer.rs +++ b/src/query/boolean_query/boolean_scorer.rs @@ -1,6 +1,5 @@ use query::Scorer; use DocId; -use Score; use std::collections::BinaryHeap; use std::cmp::Ordering; use postings::DocSet; @@ -38,17 +37,13 @@ pub struct BooleanScorer { queue: BinaryHeap, doc: DocId, score_combiner: ScoreCombiner, - filter: OccurFilter, + occur_filter: OccurFilter, } impl BooleanScorer { - pub fn set_score_combiner(&mut self, score_combiner: ScoreCombiner) { - self.score_combiner = score_combiner; - } - pub fn new(postings: Vec, - filter: OccurFilter) -> BooleanScorer { + occur_filter: OccurFilter) -> BooleanScorer { let score_combiner = ScoreCombiner::default_for_num_scorers(postings.len()); let mut non_empty_postings: Vec = Vec::new(); for mut posting in postings { @@ -73,7 +68,7 @@ impl BooleanScorer { queue: BinaryHeap::from(heap_items), doc: 0u32, score_combiner: score_combiner, - filter: filter, + occur_filter: occur_filter, } } @@ -131,7 +126,7 @@ impl DocSet for BooleanScorer { } self.advance_head(); } - if self.filter.accept(ord_bitset) { + if self.occur_filter.accept(ord_bitset) { return true; } } @@ -160,33 +155,21 @@ mod tests { use query::Scorer; use query::OccurFilter; use query::term_query::TermScorer; - use directory::Directory; - use directory::RAMDirectory; - use schema::Field; - use super::super::ScoreCombiner; - use std::path::Path; use query::Occur; - use postings::SegmentPostingsTestFactory; - use postings::Postings; - use fastfield::{U32FastFieldReader, U32FastFieldWriter, FastFieldSerializer}; + use fastfield::{U32FastFieldReader}; - - fn abs_diff(left: f32, right: f32) -> f32 { (right - left).abs() } - - lazy_static! { - static ref segment_postings_test_factory: SegmentPostingsTestFactory = SegmentPostingsTestFactory::default(); - } - + #[test] pub fn test_boolean_scorer() { let occurs = vec!(Occur::Should, Occur::Should); let occur_filter = OccurFilter::new(&occurs); let left_fieldnorms = U32FastFieldReader::from(vec!(100,200,300)); - let left = segment_postings_test_factory.from_data(vec!(1, 2, 3)); + + let left = VecPostings::from(vec!(1, 2, 3)); let left_scorer = TermScorer { idf: 1f32, fieldnorm_reader: left_fieldnorms, @@ -194,22 +177,22 @@ mod tests { }; let right_fieldnorms = U32FastFieldReader::from(vec!(15,25,35)); - let right = segment_postings_test_factory.from_data(vec!(1, 3, 8)); - let mut right_scorer = TermScorer { + let right = VecPostings::from(vec!(1, 3, 8)); + + let right_scorer = TermScorer { idf: 4f32, fieldnorm_reader: right_fieldnorms, segment_postings: right, }; - let score_combiner = ScoreCombiner::from(vec!(0f32, 1f32, 2f32)); + let mut boolean_scorer = BooleanScorer::new(vec!(left_scorer, right_scorer), occur_filter); - boolean_scorer.set_score_combiner(score_combiner); assert_eq!(boolean_scorer.next(), Some(1u32)); - assert!(abs_diff(boolean_scorer.score(), 1.7414213) < 0.001); + assert!(abs_diff(boolean_scorer.score(), 0.8707107) < 0.001); assert_eq!(boolean_scorer.next(), Some(2u32)); - assert!(abs_diff(boolean_scorer.score(), 0.057735026) < 0.001f32); + assert!(abs_diff(boolean_scorer.score(), 0.028867513) < 0.001f32); assert_eq!(boolean_scorer.next(), Some(3u32)); assert_eq!(boolean_scorer.next(), Some(8u32)); - assert!(abs_diff(boolean_scorer.score(), 1.0327955) < 0.001f32); + assert!(abs_diff(boolean_scorer.score(), 0.5163978) < 0.001f32); assert!(!boolean_scorer.advance()); } @@ -219,9 +202,9 @@ mod tests { let left_fieldnorms = U32FastFieldReader::from(vec!(10, 4)); assert_eq!(left_fieldnorms.get(0), 10); assert_eq!(left_fieldnorms.get(1), 4); - let left = segment_postings_test_factory.from_data(vec!(1)); + let left = VecPostings::from(vec!(1)); let mut left_scorer = TermScorer { - idf: 0.30685282, // 1f32, + idf: 0.30685282, fieldnorm_reader: left_fieldnorms, segment_postings: left, }; diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index ee0c37860..930b47348 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -1,6 +1,7 @@ use query::Weight; use core::SegmentReader; use query::Scorer; +use super::BooleanScorer; use query::OccurFilter; use Result; @@ -23,11 +24,13 @@ impl BooleanWeight { impl Weight for BooleanWeight { fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { - // BooleanScorer { - - // } - panic!(""); - + let sub_scorers: Vec> = try!( + self.weights + .iter() + .map(|weight| weight.scorer(reader)) + .collect() + ); + let boolean_scorer = BooleanScorer::new(sub_scorers, self.occur_filter); + Ok(box boolean_scorer) } - -} \ No newline at end of file +} diff --git a/src/query/empty_scorer.rs b/src/query/empty_scorer.rs deleted file mode 100644 index 0c1e989cf..000000000 --- a/src/query/empty_scorer.rs +++ /dev/null @@ -1,22 +0,0 @@ -use query::Scorer; -use DocSet; -use Score; -use DocId; - -pub struct EmptyScorer; - -impl Scorer for EmptyScorer { - fn score(&self) -> Score { - 0f32 - } -} - -impl DocSet for EmptyScorer { - fn advance(&mut self) -> bool { - false - } - - fn doc(&self) -> DocId { - 0 - } -} diff --git a/src/query/mod.rs b/src/query/mod.rs index 788aa6760..a74d71f88 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -16,13 +16,8 @@ mod similarity; mod weight; mod occur_filter; mod term_query; -mod empty_scorer; - - -pub use self::empty_scorer::EmptyScorer; pub use self::occur_filter::OccurFilter; - pub use self::similarity::Similarity; pub use self::boolean_query::BooleanQuery; pub use self::occur::Occur; @@ -35,4 +30,4 @@ pub use self::query_parser::QueryParser; pub use self::explanation::Explanation; pub use self::multi_term_accumulator::MultiTermAccumulator; pub use self::query_parser::ParsingError; -pub use self::weight::Weight; \ No newline at end of file +pub use self::weight::Weight; diff --git a/src/query/multi_term_query.rs b/src/query/multi_term_query.rs index 511489961..42135678a 100644 --- a/src/query/multi_term_query.rs +++ b/src/query/multi_term_query.rs @@ -11,7 +11,6 @@ use query::occur_filter::OccurFilter; use query::term_query::{TermQuery, TermWeight, TermScorer}; use query::boolean_query::BooleanScorer; - struct MultiTermWeight { weights: Vec, occur_filter: OccurFilter, @@ -21,12 +20,10 @@ struct MultiTermWeight { impl Weight for MultiTermWeight { fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { - let mut term_scorers: Vec> = Vec::new(); + let mut term_scorers: Vec> = Vec::new(); for term_weight in &self.weights { - let term_scorer_option = try!(term_weight.specialized_scorer(reader)); - if let Some(term_scorer) = term_scorer_option { - term_scorers.push(term_scorer); - } + let term_scorer = try!(term_weight.specialized_scorer(reader)); + term_scorers.push(term_scorer); } Ok(box BooleanScorer::new(term_scorers, self.occur_filter.clone())) } diff --git a/src/query/occur.rs b/src/query/occur.rs index 86bade98d..1f42b4c63 100644 --- a/src/query/occur.rs +++ b/src/query/occur.rs @@ -2,9 +2,9 @@ /// should be present or must not be present. #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub enum Occur { - /// The term should be present in the document. - /// Document without the term will be considered - /// in scoring as well. + /// For a given document to be considered for scoring, + /// at least one of the document with the Should or the Must + /// Occur constraint must be within the document. Should, /// Document without the term are excluded from the search. Must, diff --git a/src/query/occur_filter.rs b/src/query/occur_filter.rs index 188fc637f..53280fa6c 100644 --- a/src/query/occur_filter.rs +++ b/src/query/occur_filter.rs @@ -1,16 +1,25 @@ use query::Occur; -#[derive(Clone)] + +/// An OccurFilter represents a filter over a bitset of +// at most 64 elements. +/// +/// It wraps some simple bitmask to compute the filter +/// rapidly. +#[derive(Clone, Copy)] pub struct OccurFilter { and_mask: u64, result: u64, } impl OccurFilter { + + /// Returns true if the bitset is matching the occur list. pub fn accept(&self, ord_set: u64) -> bool { (self.and_mask & ord_set) == self.result } + /// Builds an `OccurFilter` from a list of `Occur`. pub fn new(occurs: &[Occur]) -> OccurFilter { let mut and_mask = 0u64; let mut result = 0u64; diff --git a/src/query/scorer.rs b/src/query/scorer.rs index 32ed4b443..75137057d 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -1,7 +1,10 @@ use DocSet; use collector::Collector; +use std::ops::{Deref, DerefMut}; -/// Scored `DocSet` +/// Scored set of documents matching a query within a specific segment. +/// +/// See [Query](./trait.Query.html). pub trait Scorer: DocSet { /// Returns the score. @@ -9,6 +12,8 @@ pub trait Scorer: DocSet { /// This method will perform a bit of computation and is not cached. fn score(&self,) -> f32; + /// Consumes the complete `DocSet` and + /// push the scored documents to the collector. fn collect(&mut self, collector: &mut Collector) { while self.advance() { collector.collect(self.doc(), self.score()); @@ -16,3 +21,16 @@ pub trait Scorer: DocSet { } } + +impl<'a> Scorer for Box { + fn score(&self,) -> f32 { + self.deref().score() + } + + fn collect(&mut self, collector: &mut Collector) { + let scorer = self.deref_mut(); + while scorer.advance() { + collector.collect(scorer.doc(), scorer.score()); + } + } +} \ No newline at end of file diff --git a/src/query/term_query/term_query.rs b/src/query/term_query/term_query.rs index 6049d8455..11536ac3f 100644 --- a/src/query/term_query/term_query.rs +++ b/src/query/term_query/term_query.rs @@ -6,16 +6,31 @@ use query::Weight; use Searcher; use std::any::Any; +/// A Term query matches all of the documents +/// containing a specific term. +/// +/// The score associated is defined as +/// `idf` * sqrt(`term_freq` / `field norm`) +/// in which : +/// * idf - inverse document frequency. +/// * term_freq - number of occurrences of the term in the field +/// * field norm - number of tokens in the field. #[derive(Debug)] pub struct TermQuery { term: Term, } impl TermQuery { + + /// Returns a weight object. + /// + /// While `.weight(...)` returns a boxed trait object, + /// this method return a specific implementation. + /// This is useful for optimization purpose. pub fn specialized_weight(&self, searcher: &Searcher) -> TermWeight { - let doc_freq = searcher.doc_freq(&self.term); TermWeight { - doc_freq: doc_freq, + num_docs: searcher.num_docs(), + doc_freq: searcher.doc_freq(&self.term), term: self.term.clone() } } diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index e7c3bf644..db1f48449 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -1,18 +1,17 @@ use Score; use DocId; -use postings::SegmentPostings; use fastfield::U32FastFieldReader; use postings::DocSet; use query::Scorer; use postings::Postings; -pub struct TermScorer<'a> { +pub struct TermScorer where TPostings: Postings { pub idf: Score, pub fieldnorm_reader: U32FastFieldReader, - pub segment_postings: SegmentPostings<'a>, + pub segment_postings: TPostings, } -impl<'a> DocSet for TermScorer<'a> { +impl DocSet for TermScorer where TPostings: Postings { fn advance(&mut self,) -> bool { self.segment_postings.advance() @@ -23,7 +22,7 @@ impl<'a> DocSet for TermScorer<'a> { } } -impl<'a> Scorer for TermScorer<'a> { +impl Scorer for TermScorer where TPostings: Postings { fn score(&self,) -> Score { let doc = self.segment_postings.doc(); let field_norm = self.fieldnorm_reader.get(doc); diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs index acd5ffe83..4f9990411 100644 --- a/src/query/term_query/term_weight.rs +++ b/src/query/term_query/term_weight.rs @@ -2,12 +2,14 @@ use Term; use query::Weight; use core::SegmentReader; use query::Scorer; -use query::EmptyScorer; use postings::SegmentPostingsOption; +use postings::SegmentPostings; +use fastfield::U32FastFieldReader; use super::term_scorer::TermScorer; use Result; pub struct TermWeight { + pub num_docs: u32, pub doc_freq: u32, pub term: Term, } @@ -16,33 +18,37 @@ pub struct TermWeight { impl Weight for TermWeight { fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { - let specialized_scorer_option = try!(self.specialized_scorer(reader)); - match specialized_scorer_option { - Some(term_scorer) => { - Ok(box term_scorer) - } - None => { - Ok(box EmptyScorer) - } - } + let specialized_scorer = try!(self.specialized_scorer(reader)); + Ok(box specialized_scorer) } } impl TermWeight { - pub fn specialized_scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result>> { + fn idf(&self) -> f32 { + 1.0 + (self.num_docs as f32 / (self.doc_freq as f32 + 1.0)).ln() + } + + pub fn specialized_scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result>> { let field = self.term.field(); let fieldnorm_reader = try!(reader.get_fieldnorms_reader(field)); Ok( - reader.read_postings(&self.term, SegmentPostingsOption::Freq) - .map(|segment_postings| - TermScorer { - idf: 1f32 / (self.doc_freq as f32), - fieldnorm_reader: fieldnorm_reader, - segment_postings: segment_postings, - } - ) + reader + .read_postings(&self.term, SegmentPostingsOption::Freq) + .map(|segment_postings| + TermScorer { + idf: self.idf(), + fieldnorm_reader: fieldnorm_reader, + segment_postings: segment_postings, + } + ) + .unwrap_or( + TermScorer { + idf: 1f32, + fieldnorm_reader: U32FastFieldReader::empty(), + segment_postings: SegmentPostings::empty() + }) ) } diff --git a/src/query/weight.rs b/src/query/weight.rs index 27a7afd65..db583a3e4 100644 --- a/src/query/weight.rs +++ b/src/query/weight.rs @@ -2,10 +2,15 @@ use super::Scorer; use Result; use core::SegmentReader; + +/// A Weight is the specialization of a Query +/// for a given set of segments. +/// +/// See [Query](./trait.Query.html). pub trait Weight { - + /// Returns the scorer for the given segment. + /// See [Query](./trait.Query.html). fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result>; - } From 59d1b9e2bbf464e0aa677b2f3e25409906494a21 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Wed, 2 Nov 2016 22:28:08 +0900 Subject: [PATCH 13/19] issue/50 Added phrase query --- src/query/boolean_query/boolean_scorer.rs | 12 ++- src/query/mod.rs | 10 +-- src/query/multi_term_accumulator.rs | 14 --- src/query/multi_term_query/mod.rs | 5 ++ .../multi_term_query.rs | 60 +++++-------- .../multi_term_query/multi_term_weight.rs | 33 +++++++ src/query/phrase_query.rs | 86 ------------------- src/query/phrase_query/mod.rs | 7 ++ src/query/phrase_query/phrase_query.rs | 46 ++++++++++ src/query/phrase_query/phrase_scorer.rs | 47 ++++++++++ src/query/phrase_query/phrase_weight.rs | 27 ++++++ src/query/similarity.rs | 19 ---- src/query/similarity_explainer.rs | 50 ----------- src/query/term_query/term_scorer.rs | 17 ++-- src/query/term_query/term_weight.rs | 4 +- 15 files changed, 211 insertions(+), 226 deletions(-) delete mode 100644 src/query/multi_term_accumulator.rs create mode 100644 src/query/multi_term_query/mod.rs rename src/query/{ => multi_term_query}/multi_term_query.rs (66%) create mode 100644 src/query/multi_term_query/multi_term_weight.rs delete mode 100644 src/query/phrase_query.rs create mode 100644 src/query/phrase_query/mod.rs create mode 100644 src/query/phrase_query/phrase_query.rs create mode 100644 src/query/phrase_query/phrase_scorer.rs create mode 100644 src/query/phrase_query/phrase_weight.rs delete mode 100644 src/query/similarity.rs delete mode 100644 src/query/similarity_explainer.rs diff --git a/src/query/boolean_query/boolean_scorer.rs b/src/query/boolean_query/boolean_scorer.rs index 7f7d2d24f..8f35cdc9e 100644 --- a/src/query/boolean_query/boolean_scorer.rs +++ b/src/query/boolean_query/boolean_scorer.rs @@ -8,7 +8,7 @@ use query::boolean_query::ScoreCombiner; /// Each `HeapItem` represents the head of -/// a segment postings being merged. +/// one of scorer being merged. /// /// * `doc` - is the current doc id for the given segment postings /// * `ord` - is the ordinal used to identify to which segment postings @@ -42,6 +42,10 @@ pub struct BooleanScorer { impl BooleanScorer { + pub fn scorers(&self) -> &[TScorer] { + &self.postings + } + pub fn new(postings: Vec, occur_filter: OccurFilter) -> BooleanScorer { let score_combiner = ScoreCombiner::default_for_num_scorers(postings.len()); @@ -173,7 +177,7 @@ mod tests { let left_scorer = TermScorer { idf: 1f32, fieldnorm_reader: left_fieldnorms, - segment_postings: left, + postings: left, }; let right_fieldnorms = U32FastFieldReader::from(vec!(15,25,35)); @@ -182,7 +186,7 @@ mod tests { let right_scorer = TermScorer { idf: 4f32, fieldnorm_reader: right_fieldnorms, - segment_postings: right, + postings: right, }; let mut boolean_scorer = BooleanScorer::new(vec!(left_scorer, right_scorer), occur_filter); @@ -206,7 +210,7 @@ mod tests { let mut left_scorer = TermScorer { idf: 0.30685282, fieldnorm_reader: left_fieldnorms, - segment_postings: left, + postings: left, }; left_scorer.advance(); assert!(abs_diff(left_scorer.score(), 0.15342641) < 0.001f32); diff --git a/src/query/mod.rs b/src/query/mod.rs index a74d71f88..870c871cd 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -6,28 +6,26 @@ mod query; mod boolean_query; mod multi_term_query; -mod multi_term_accumulator; -mod similarity_explainer; +mod phrase_query; mod scorer; mod query_parser; mod explanation; mod occur; -mod similarity; mod weight; mod occur_filter; mod term_query; + pub use self::occur_filter::OccurFilter; -pub use self::similarity::Similarity; pub use self::boolean_query::BooleanQuery; pub use self::occur::Occur; pub use self::query::Query; pub use self::term_query::TermQuery; +pub use self::phrase_query::PhraseQuery; pub use self::multi_term_query::MultiTermQuery; -pub use self::similarity_explainer::SimilarityExplainer; +pub use self::multi_term_query::MultiTermWeight; pub use self::scorer::Scorer; pub use self::query_parser::QueryParser; pub use self::explanation::Explanation; -pub use self::multi_term_accumulator::MultiTermAccumulator; pub use self::query_parser::ParsingError; pub use self::weight::Weight; diff --git a/src/query/multi_term_accumulator.rs b/src/query/multi_term_accumulator.rs deleted file mode 100644 index abbec0c56..000000000 --- a/src/query/multi_term_accumulator.rs +++ /dev/null @@ -1,14 +0,0 @@ - -/// Accumulator of the matching terms information -pub trait MultiTermAccumulator { - /// Update the accumulator given the information of a term. - /// - term_ord is the term_ordinal - /// - term_freq is the frequency of the term within the document - /// - fieldnorm is the number of tokens associated to the field for this document - /// - /// The term's update do not have to arrive in a specific order. - /// Terms that are not present in the document will not be updated. - fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32); - /// Resets the accumulator - fn clear(&mut self,); -} diff --git a/src/query/multi_term_query/mod.rs b/src/query/multi_term_query/mod.rs new file mode 100644 index 000000000..38cd209a4 --- /dev/null +++ b/src/query/multi_term_query/mod.rs @@ -0,0 +1,5 @@ +mod multi_term_query; +mod multi_term_weight; + +pub use self::multi_term_query::MultiTermQuery; +pub use self::multi_term_weight::MultiTermWeight; \ No newline at end of file diff --git a/src/query/multi_term_query.rs b/src/query/multi_term_query/multi_term_query.rs similarity index 66% rename from src/query/multi_term_query.rs rename to src/query/multi_term_query/multi_term_query.rs index 42135678a..f7298a65b 100644 --- a/src/query/multi_term_query.rs +++ b/src/query/multi_term_query/multi_term_query.rs @@ -1,33 +1,14 @@ use Result; -use super::Weight; +use query::Weight; use std::any::Any; use schema::Term; +use query::MultiTermWeight; use query::Query; use core::searcher::Searcher; -use core::SegmentReader; -use query::Scorer; use query::occur::Occur; use query::occur_filter::OccurFilter; -use query::term_query::{TermQuery, TermWeight, TermScorer}; -use query::boolean_query::BooleanScorer; +use query::term_query::TermQuery; -struct MultiTermWeight { - weights: Vec, - occur_filter: OccurFilter, -} - - -impl Weight for MultiTermWeight { - - fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { - let mut term_scorers: Vec> = Vec::new(); - for term_weight in &self.weights { - let term_scorer = try!(term_weight.specialized_scorer(reader)); - term_scorers.push(term_scorer); - } - Ok(box BooleanScorer::new(term_scorers, self.occur_filter.clone())) - } -} /// Query involving one or more terms. @@ -44,17 +25,7 @@ impl MultiTermQuery { self.occur_terms.len() } -} - - - -impl Query for MultiTermQuery { - - fn as_any(&self) -> &Any { - self - } - - fn weight(&self, searcher: &Searcher) -> Result> { + pub fn specialized_weight(&self, searcher: &Searcher) -> MultiTermWeight { let term_queries: Vec = self.occur_terms .iter() .map(|&(_, ref term)| TermQuery::from(term.clone())) @@ -67,12 +38,23 @@ impl Query for MultiTermQuery { let weights = term_queries.iter() .map(|term_query| term_query.specialized_weight(searcher)) .collect(); - Ok( - Box::new(MultiTermWeight { - weights: weights, - occur_filter: occur_filter, - }) - ) + MultiTermWeight { + weights: weights, + occur_filter: occur_filter, + } + } +} + + + +impl Query for MultiTermQuery { + + fn as_any(&self) -> &Any { + self + } + + fn weight(&self, searcher: &Searcher) -> Result> { + Ok(box self.specialized_weight(searcher)) } } diff --git a/src/query/multi_term_query/multi_term_weight.rs b/src/query/multi_term_query/multi_term_weight.rs new file mode 100644 index 000000000..6e12cd7a8 --- /dev/null +++ b/src/query/multi_term_query/multi_term_weight.rs @@ -0,0 +1,33 @@ +use Result; +use query::Weight; +use core::SegmentReader; +use query::Scorer; +use query::occur_filter::OccurFilter; +use postings::SegmentPostings; +use query::term_query::{TermWeight, TermScorer}; +use query::boolean_query::BooleanScorer; + +pub struct MultiTermWeight { + pub weights: Vec, + pub occur_filter: OccurFilter, +} + +impl MultiTermWeight { + + pub fn specialized_scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result>>> { + let mut term_scorers: Vec> = Vec::new(); + for term_weight in &self.weights { + let term_scorer = try!(term_weight.specialized_scorer(reader)); + term_scorers.push(term_scorer); + } + Ok(BooleanScorer::new(term_scorers, self.occur_filter)) + } + +} + +impl Weight for MultiTermWeight { + + fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { + Ok(box try!(self.specialized_scorer(reader))) + } +} \ No newline at end of file diff --git a/src/query/phrase_query.rs b/src/query/phrase_query.rs deleted file mode 100644 index 81e1418f6..000000000 --- a/src/query/phrase_query.rs +++ /dev/null @@ -1,86 +0,0 @@ -use schema::Term; -use query::Query; -use common::TimerTree; -use common::OpenTimer; -use std::io; -use core::searcher::Searcher; -use collector::Collector; -use core::searcher::SegmentLocalId; -use core::SegmentReader; -use postings::Postings; -use postings::SegmentPostings; -use postings::intersection; - -pub struct PhraseQuery { - terms: Vec, -} - -impl Query for PhraseQuery { - - fn search(&self, searcher: &Searcher, collector: &mut Collector) -> io::Result { - let mut timer_tree = TimerTree::default(); - { - let mut search_timer = timer_tree.open("search"); - for (segment_ord, segment_reader) in searcher.segments().iter().enumerate() { - let mut segment_search_timer = search_timer.open("segment_search"); - { - let _ = segment_search_timer.open("set_segment"); - try!(collector.set_segment(segment_ord as SegmentLocalId, &segment_reader)); - } - let mut postings = self.search_segment(segment_reader, segment_search_timer.open("get_postings")); - { - let _collection_timer = segment_search_timer.open("collection"); - while postings.next() { - collector.collect(postings.doc()); - } - } - } - } - Ok(timer_tree) - } -} - -impl PhraseQuery { - pub fn new(terms: Vec) -> PhraseQuery { - PhraseQuery { - terms: terms, - } - } - - fn search_segment<'a, 'b>(&'b self, reader: &'b SegmentReader, mut timer: OpenTimer<'a>) -> Box { - if self.terms.len() == 1 { - match reader.get_term(&self.terms[0]) { - Some(term_info) => { - let postings: SegmentPostings<'b> = reader.read_postings(&term_info); - Box::new(postings) - }, - None => { - Box::new(SegmentPostings::empty()) - }, - } - } else { - let mut segment_postings: Vec = Vec::new(); - { - let mut decode_timer = timer.open("decode_all"); - for term in self.terms.iter() { - match reader.get_term(term) { - Some(term_info) => { - let _decode_one_timer = decode_timer.open("decode_one"); - let segment_posting = reader.read_postings_with_positions(&term_info); - segment_postings.push(segment_posting); - } - None => { - // currently this is a strict intersection. - return Box::new(SegmentPostings::empty()); - } - } - } - } - Box::new(intersection(segment_postings)) - } - } -} - - - - diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs new file mode 100644 index 000000000..763e72b45 --- /dev/null +++ b/src/query/phrase_query/mod.rs @@ -0,0 +1,7 @@ +mod phrase_query; +mod phrase_weight; +mod phrase_scorer; + +pub use self::phrase_query::PhraseQuery; +pub use self::phrase_weight::PhraseWeight; +pub use self::phrase_scorer::PhraseScorer; \ No newline at end of file diff --git a/src/query/phrase_query/phrase_query.rs b/src/query/phrase_query/phrase_query.rs new file mode 100644 index 000000000..790e2436d --- /dev/null +++ b/src/query/phrase_query/phrase_query.rs @@ -0,0 +1,46 @@ +use schema::Term; +use query::Query; +use core::searcher::Searcher; +use query::Occur; +use super::PhraseWeight; +use query::MultiTermQuery; +use std::any::Any; +use query::Weight; +use Result; + + +#[derive(Debug)] +pub struct PhraseQuery { + all_terms_query: MultiTermQuery, +} + +impl Query for PhraseQuery { + + + /// Used to make it possible to cast Box + /// into a specific type. This is mostly useful for unit tests. + fn as_any(&self) -> &Any { + self + } + + /// Create the weight associated to a query. + /// + /// See [Weight](./trait.Weight.html). + fn weight(&self, searcher: &Searcher) -> Result> { + let multi_term_weight = self.all_terms_query.specialized_weight(searcher); + Ok(box PhraseWeight::from(multi_term_weight)) + } + +} + +impl PhraseQuery { + pub fn new(terms: Vec) -> PhraseQuery { + assert!(terms.len() > 1); + let occur_terms: Vec<(Occur, Term)> = terms.into_iter() + .map(|term| (Occur::Must, term)) + .collect(); + PhraseQuery { + all_terms_query: MultiTermQuery::from(occur_terms), + } + } +} diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs new file mode 100644 index 000000000..4f3285baf --- /dev/null +++ b/src/query/phrase_query/phrase_scorer.rs @@ -0,0 +1,47 @@ +use query::Scorer; +use DocSet; +use query::term_query::TermScorer; +use query::boolean_query::BooleanScorer; +use postings::SegmentPostings; +use postings::Postings; +use DocId; + +pub struct PhraseScorer<'a> { + pub all_term_scorer: BooleanScorer>> +} + +impl<'a> PhraseScorer<'a> { + fn phrase_match(&self) -> bool { + let scorers = self.all_term_scorer.scorers(); + for scorer in scorers { + let positions = scorer.postings().positions(); + } + true + // self.all_term_scorer.positions(); + // let positions = + + } +} + +impl<'a> DocSet for PhraseScorer<'a> { + fn advance(&mut self,) -> bool { + while self.all_term_scorer.advance() { + if self.phrase_match() { + return true; + } + } + false + } + + fn doc(&self,) -> DocId { + self.all_term_scorer.doc() + } +} + + +impl<'a> Scorer for PhraseScorer<'a> { + fn score(&self,) -> f32 { + self.all_term_scorer.score() + } + +} diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs new file mode 100644 index 000000000..bd8f271f6 --- /dev/null +++ b/src/query/phrase_query/phrase_weight.rs @@ -0,0 +1,27 @@ +use query::Weight; +use query::Scorer; +use core::SegmentReader; +use super::PhraseScorer; +use query::MultiTermWeight; +use Result; + +pub struct PhraseWeight { + all_term_weight: MultiTermWeight, +} + +impl From for PhraseWeight { + fn from(all_term_weight: MultiTermWeight) -> PhraseWeight { + PhraseWeight { + all_term_weight: all_term_weight + } + } +} + +impl Weight for PhraseWeight { + fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { + let all_term_scorer = try!(self.all_term_weight.specialized_scorer(reader)); + Ok(box PhraseScorer { + all_term_scorer: all_term_scorer + }) + } +} diff --git a/src/query/similarity.rs b/src/query/similarity.rs deleted file mode 100644 index 1ac371af7..000000000 --- a/src/query/similarity.rs +++ /dev/null @@ -1,19 +0,0 @@ -use Score; -use query::Explanation; -use query::MultiTermAccumulator; - -/// Similarity score -pub trait Similarity: MultiTermAccumulator { - - /// Compute and returns the similarity score, - /// - /// The results are not cached. - fn score(&self, ) -> Score; - - /// Explain the computation of this similarity given all of - /// terms information. - /// - /// `vals` is an array of `(term_ord, term_freq, field_norm)`. - /// Terms that are not present should not appear in the array. - fn explain(&self, vals: &[(usize, u32, u32)]) -> Explanation; -} diff --git a/src/query/similarity_explainer.rs b/src/query/similarity_explainer.rs deleted file mode 100644 index 996b778a7..000000000 --- a/src/query/similarity_explainer.rs +++ /dev/null @@ -1,50 +0,0 @@ -use Score; -use super::MultiTermAccumulator; -use super::Similarity; -use super::Explanation; - - -/// Wrapper over a similarity used to run `explain` -pub struct SimilarityExplainer { - scorer: TSimilarity, - vals: Vec<(usize, u32, u32)>, -} - -impl SimilarityExplainer { - /// Returns the underlying similary's score explanation. - pub fn explain_score(&self,) -> Explanation { - self.scorer.explain(&self.vals) - } -} - -impl From for SimilarityExplainer { - fn from(multi_term_scorer: TSimilarity) -> SimilarityExplainer { - SimilarityExplainer { - scorer: multi_term_scorer, - vals: Vec::new(), - } - } -} - -impl MultiTermAccumulator for SimilarityExplainer { - fn update(&mut self, term_ord: usize, term_freq: u32, fieldnorm: u32) { - self.vals.push((term_ord, term_freq, fieldnorm)); - self.scorer.update(term_ord, term_freq, fieldnorm); - } - - fn clear(&mut self,) { - self.vals.clear(); - self.scorer.clear(); - } -} - -impl Similarity for SimilarityExplainer { - - fn score(&self,) -> Score { - self.scorer.score() - } - - fn explain(&self, vals: &[(usize, u32, u32)]) -> Explanation { - self.scorer.explain(vals) - } -} diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index db1f48449..c12b24174 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -8,25 +8,30 @@ use postings::Postings; pub struct TermScorer where TPostings: Postings { pub idf: Score, pub fieldnorm_reader: U32FastFieldReader, - pub segment_postings: TPostings, + pub postings: TPostings, +} + +impl TermScorer where TPostings: Postings { + pub fn postings(&self) -> &TPostings { + &self.postings + } } impl DocSet for TermScorer where TPostings: Postings { - fn advance(&mut self,) -> bool { - self.segment_postings.advance() + self.postings.advance() } fn doc(&self,) -> DocId { - self.segment_postings.doc() + self.postings.doc() } } impl Scorer for TermScorer where TPostings: Postings { fn score(&self,) -> Score { - let doc = self.segment_postings.doc(); + let doc = self.postings.doc(); let field_norm = self.fieldnorm_reader.get(doc); - self.idf * (self.segment_postings.term_freq() as f32 / field_norm as f32).sqrt() + self.idf * (self.postings.term_freq() as f32 / field_norm as f32).sqrt() } } diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs index 4f9990411..9637d1a91 100644 --- a/src/query/term_query/term_weight.rs +++ b/src/query/term_query/term_weight.rs @@ -40,14 +40,14 @@ impl TermWeight { TermScorer { idf: self.idf(), fieldnorm_reader: fieldnorm_reader, - segment_postings: segment_postings, + postings: segment_postings, } ) .unwrap_or( TermScorer { idf: 1f32, fieldnorm_reader: U32FastFieldReader::empty(), - segment_postings: SegmentPostings::empty() + postings: SegmentPostings::empty() }) ) } From a2c6ec93e0e70ca12b634dc9e657ee5f73e7b3f6 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Thu, 3 Nov 2016 14:28:14 +0900 Subject: [PATCH 14/19] issue/50 Fixed VecPostings... Changed intersections. --- src/postings/freq_handler.rs | 4 +- src/postings/intersection.rs | 88 +++++++------------ src/postings/mod.rs | 49 +++++++++-- src/postings/segment_postings_option.rs | 3 +- src/postings/vec_postings.rs | 73 ++++++--------- src/query/boolean_query/boolean_scorer.rs | 73 +-------------- src/query/boolean_query/mod.rs | 71 ++++++++++++++- .../multi_term_query/multi_term_query.rs | 7 +- src/query/phrase_query/mod.rs | 63 ++++++++++++- src/query/phrase_query/phrase_query.rs | 5 +- src/query/phrase_query/phrase_scorer.rs | 43 +++++++-- src/query/phrase_query/phrase_weight.rs | 4 +- src/query/term_query/term_query.rs | 4 +- src/query/term_query/term_weight.rs | 5 +- 14 files changed, 291 insertions(+), 201 deletions(-) diff --git a/src/postings/freq_handler.rs b/src/postings/freq_handler.rs index ea9cb6ae6..660ac4977 100644 --- a/src/postings/freq_handler.rs +++ b/src/postings/freq_handler.rs @@ -95,7 +95,9 @@ impl FreqHandler { pub fn positions(&self, idx: usize) -> &[u32] { let start = self.positions_offsets[idx]; let stop = self.positions_offsets[idx + 1]; - &self.positions[start..stop] + println!("{} -> {}", start, stop); + println!("{} {:?}", idx, &self.positions_offsets[..10]); + &self.positions[start..stop] } /// Decompresses a complete frequency block diff --git a/src/postings/intersection.rs b/src/postings/intersection.rs index 68cb1ec3b..2f49910ed 100644 --- a/src/postings/intersection.rs +++ b/src/postings/intersection.rs @@ -1,90 +1,64 @@ use postings::DocSet; +use postings::SkipResult; use std::cmp::Ordering; use DocId; // TODO Find a way to specialize `IntersectionDocSet` /// Creates a `DocSet` that iterator through the intersection of two `DocSet`s. -pub struct IntersectionDocSet<'a> { - left: Box, - right: Box, +pub struct IntersectionDocSet { + docsets: Vec, finished: bool, + doc: DocId, } -impl<'a> IntersectionDocSet<'a> { - - /// Intersect two `DocSet`s - fn from_pair(left: Box, right: Box) -> IntersectionDocSet<'a> { +impl From> for IntersectionDocSet { + fn from(docsets: Vec) -> IntersectionDocSet { + assert!(docsets.len() >= 2); IntersectionDocSet { - left: left, - right: right, + docsets: docsets, finished: false, - } - } - - /// Intersect a list of `DocSet`s - pub fn new(mut postings: Vec>) -> IntersectionDocSet<'a> { - let left = postings.pop().unwrap(); - let right = - if postings.len() == 1 { - postings.pop().unwrap() - } - else { - Box::new(IntersectionDocSet::new(postings)) - }; - IntersectionDocSet::from_pair(left, right) + doc: DocId::max_value(), + } } } -impl<'a> DocSet for IntersectionDocSet<'a> { +impl DocSet for IntersectionDocSet { fn advance(&mut self,) -> bool { if self.finished { return false; } - - if !self.left.advance() { - self.finished = true; - return false; - } - if !self.right.advance() { - self.finished = true; - return false; - } - loop { - match self.left.doc().cmp(&self.right.doc()) { - Ordering::Equal => { - return true; + + 'outter: loop { + let doc_candidate = { + let mut first_docset = &mut self.docsets[0]; + if !first_docset.advance() { + self.finished = true; + return false; } - Ordering::Less => { - if !self.left.advance() { - self.finished = true; - return false; - } - } - Ordering::Greater => { - if !self.right.advance() { + first_docset.doc() + }; + for docset_ord in 1..self.docsets.len() { + let docset: &mut TDocSet = &mut self.docsets[docset_ord]; + match docset.skip_next(doc_candidate) { + SkipResult::End => { self.finished = true; return false; } + SkipResult::OverStep => { + continue 'outter; + }, + SkipResult::Reached => {} } } + self.doc = doc_candidate; + return true; } } fn doc(&self,) -> DocId { - self.left.doc() + self.doc } } - -/// Intersects a `Vec` of `DocSets` -pub fn intersection<'a, TDocSet: DocSet + 'a>(postings: Vec) -> IntersectionDocSet<'a> { - let boxed_postings: Vec> = postings - .into_iter() - .map(|postings: TDocSet| { - Box::new(postings) as Box - }) - .collect(); - IntersectionDocSet::new(boxed_postings) -} diff --git a/src/postings/mod.rs b/src/postings/mod.rs index 0da67d2e5..7992f629e 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -32,7 +32,6 @@ pub use self::vec_postings::VecPostings; pub use self::chained_postings::ChainedPostings; pub use self::segment_postings::SegmentPostings; -pub use self::intersection::intersection; pub use self::intersection::IntersectionDocSet; pub use self::freq_handler::FreqHandler; @@ -50,6 +49,7 @@ mod tests { use core::Index; use std::iter; use datastruct::stacker::Heap; + use query::{Query, TermQuery}; #[test] @@ -73,7 +73,7 @@ mod tests { } #[test] - pub fn test_position_and_fieldnorm_write_fullstack() { + pub fn test_position_and_fieldnorm() { let mut schema_builder = SchemaBuilder::default(); let text_field = schema_builder.add_text_field("text", TEXT); let schema = schema_builder.build(); @@ -154,12 +154,43 @@ mod tests { } } + #[test] + pub fn test_position_and_fieldnorm2() { + let mut schema_builder = SchemaBuilder::default(); + let text_field = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + { + let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); + { + let mut doc = Document::default(); + doc.add_text(text_field, "g b b d c g c"); + index_writer.add_document(doc).unwrap(); + } + { + let mut doc = Document::default(); + doc.add_text(text_field, "g a b b a d c g c"); + index_writer.add_document(doc).unwrap(); + } + assert!(index_writer.commit().is_ok()); + } + let term_query = TermQuery::from(Term::from_field_text(text_field, "a")); + let searcher = index.searcher(); + let mut term_weight = term_query.specialized_weight(&*searcher); + term_weight.segment_postings_options = SegmentPostingsOption::FreqAndPositions; + let segment_reader = &searcher.segment_readers()[0]; + let mut term_scorer = term_weight.specialized_scorer(segment_reader).unwrap(); + assert!(term_scorer.advance()); + assert_eq!(term_scorer.doc(), 1u32); + assert_eq!(term_scorer.postings().positions(), &[1u32, 4]); + } + #[test] fn test_intersection() { { - let left = Box::new(VecPostings::from(vec!(1, 3, 9))); - let right = Box::new(VecPostings::from(vec!(3, 4, 9, 18))); - let mut intersection = IntersectionDocSet::new(vec!(left, right)); + let left = VecPostings::from(vec!(1, 3, 9)); + let right = VecPostings::from(vec!(3, 4, 9, 18)); + let mut intersection = IntersectionDocSet::from(vec!(left, right)); assert!(intersection.advance()); assert_eq!(intersection.doc(), 3); assert!(intersection.advance()); @@ -167,10 +198,10 @@ mod tests { assert!(!intersection.advance()); } { - let a = Box::new(VecPostings::from(vec!(1, 3, 9))); - let b = Box::new(VecPostings::from(vec!(3, 4, 9, 18))); - let c = Box::new(VecPostings::from(vec!(1, 5, 9, 111))); - let mut intersection = IntersectionDocSet::new(vec!(a, b, c)); + let a = VecPostings::from(vec!(1, 3, 9)); + let b = VecPostings::from(vec!(3, 4, 9, 18)); + let c = VecPostings::from(vec!(1, 5, 9, 111)); + let mut intersection = IntersectionDocSet::from(vec!(a, b, c)); assert!(intersection.advance()); assert_eq!(intersection.doc(), 9); assert!(!intersection.advance()); diff --git a/src/postings/segment_postings_option.rs b/src/postings/segment_postings_option.rs index 70dfd97e4..082ea0660 100644 --- a/src/postings/segment_postings_option.rs +++ b/src/postings/segment_postings_option.rs @@ -5,7 +5,8 @@ /// Since decoding information is not free, this makes it possible to /// avoid this extra cost when the information is not required. /// For instance, positions are useful when running phrase queries -/// but useless in other queries, +/// but useless in other queries. +#[derive(Clone, Copy)] pub enum SegmentPostingsOption { /// Only the doc ids are decoded NoFreq, diff --git a/src/postings/vec_postings.rs b/src/postings/vec_postings.rs index 8d4ba0d48..02bcf699d 100644 --- a/src/postings/vec_postings.rs +++ b/src/postings/vec_postings.rs @@ -37,59 +37,29 @@ impl DocSet for VecPostings { } fn skip_next(&mut self, target: DocId) -> SkipResult { - let mut start: usize = self.cursor.0; - match self.doc_ids[start].cmp(&target) { - Ordering::Equal => { - return SkipResult::Reached; - } - Ordering::Greater => { - if self.cursor.0 < self.doc_ids.len() { - return SkipResult::OverStep; + let next_id: usize = (self.cursor + Wrapping( + if self.cursor.0 == usize::max_value() { + 1 } else { - return SkipResult::End; + 0 } - } - Ordering::Less => { - // see below - } - } - - let mut end = self.doc_ids.len(); - - while end - start > 1 { - // find an upper bound - let mut jump = 1; - loop { - let jump_dest = start + jump; - if jump_dest >= end { - // we jump out of bounds - break; + )).0; + for i in next_id .. self.doc_ids.len() { + let doc: DocId = self.doc_ids[i]; + match doc.cmp(&target) { + Ordering::Equal => { + self.cursor = Wrapping(i); + return SkipResult::Reached; } - match self.doc_ids[jump_dest].cmp(&target) { - Ordering::Less => { - // still below the target, let's keep jumping. - start = jump_dest; - jump *= 2; - } - Ordering::Equal => { - self.cursor = Wrapping(jump_dest); - return SkipResult::Reached; - } - Ordering::Greater => { - end = jump_dest; - break; - } - } + Ordering::Greater => { + self.cursor = Wrapping(i); + return SkipResult::OverStep; + } + Ordering::Less => {} } } - self.cursor = Wrapping(start + 1); - if self.cursor.0 < self.doc_ids.len() { - SkipResult::OverStep - } - else { - SkipResult::End - } + SkipResult::End } } @@ -132,5 +102,14 @@ pub mod tests { assert_eq!(postings.doc(), 300u32); assert_eq!(postings.skip_next(6000u32), SkipResult::End); } + + #[test] + pub fn test_vec_postings_skip_without_advance() { + let doc_ids: Vec = (0u32..1024u32).map(|e| e*3).collect(); + let mut postings = VecPostings::from(doc_ids); + assert_eq!(postings.skip_next(300u32), SkipResult::Reached); + assert_eq!(postings.doc(), 300u32); + assert_eq!(postings.skip_next(6000u32), SkipResult::End); + } } diff --git a/src/query/boolean_query/boolean_scorer.rs b/src/query/boolean_query/boolean_scorer.rs index 8f35cdc9e..4be2cd7c9 100644 --- a/src/query/boolean_query/boolean_scorer.rs +++ b/src/query/boolean_query/boolean_scorer.rs @@ -77,6 +77,10 @@ impl BooleanScorer { } } + pub fn num_subscorers(&self) -> usize { + self.postings.len() + } + /// Advances the head of our heap (the segment postings with the lowest doc) /// It will also update the new current `DocId` as well as the term frequency @@ -148,72 +152,3 @@ impl Scorer for BooleanScorer { } } - - - -#[cfg(test)] -mod tests { - - use super::*; - use postings::{DocSet, VecPostings}; - use query::Scorer; - use query::OccurFilter; - use query::term_query::TermScorer; - use query::Occur; - use fastfield::{U32FastFieldReader}; - - fn abs_diff(left: f32, right: f32) -> f32 { - (right - left).abs() - } - - #[test] - pub fn test_boolean_scorer() { - let occurs = vec!(Occur::Should, Occur::Should); - let occur_filter = OccurFilter::new(&occurs); - - let left_fieldnorms = U32FastFieldReader::from(vec!(100,200,300)); - - let left = VecPostings::from(vec!(1, 2, 3)); - let left_scorer = TermScorer { - idf: 1f32, - fieldnorm_reader: left_fieldnorms, - postings: left, - }; - - let right_fieldnorms = U32FastFieldReader::from(vec!(15,25,35)); - let right = VecPostings::from(vec!(1, 3, 8)); - - let right_scorer = TermScorer { - idf: 4f32, - fieldnorm_reader: right_fieldnorms, - postings: right, - }; - - let mut boolean_scorer = BooleanScorer::new(vec!(left_scorer, right_scorer), occur_filter); - assert_eq!(boolean_scorer.next(), Some(1u32)); - assert!(abs_diff(boolean_scorer.score(), 0.8707107) < 0.001); - assert_eq!(boolean_scorer.next(), Some(2u32)); - assert!(abs_diff(boolean_scorer.score(), 0.028867513) < 0.001f32); - assert_eq!(boolean_scorer.next(), Some(3u32)); - assert_eq!(boolean_scorer.next(), Some(8u32)); - assert!(abs_diff(boolean_scorer.score(), 0.5163978) < 0.001f32); - assert!(!boolean_scorer.advance()); - } - - - #[test] - pub fn test_term_scorer() { - let left_fieldnorms = U32FastFieldReader::from(vec!(10, 4)); - assert_eq!(left_fieldnorms.get(0), 10); - assert_eq!(left_fieldnorms.get(1), 4); - let left = VecPostings::from(vec!(1)); - let mut left_scorer = TermScorer { - idf: 0.30685282, - fieldnorm_reader: left_fieldnorms, - postings: left, - }; - left_scorer.advance(); - assert!(abs_diff(left_scorer.score(), 0.15342641) < 0.001f32); - } - -} diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 3f19cb92e..43d0e3336 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -7,4 +7,73 @@ mod score_combiner; pub use self::boolean_query::BooleanQuery; pub use self::boolean_clause::BooleanClause; pub use self::boolean_scorer::BooleanScorer; -pub use self::score_combiner::ScoreCombiner; \ No newline at end of file +pub use self::score_combiner::ScoreCombiner; + + + +#[cfg(test)] +mod tests { + + use super::*; + use postings::{DocSet, VecPostings}; + use query::Scorer; + use query::OccurFilter; + use query::term_query::TermScorer; + use query::Occur; + use fastfield::{U32FastFieldReader}; + + fn abs_diff(left: f32, right: f32) -> f32 { + (right - left).abs() + } + + #[test] + pub fn test_boolean_scorer() { + let occurs = vec!(Occur::Should, Occur::Should); + let occur_filter = OccurFilter::new(&occurs); + + let left_fieldnorms = U32FastFieldReader::from(vec!(100,200,300)); + + let left = VecPostings::from(vec!(1, 2, 3)); + let left_scorer = TermScorer { + idf: 1f32, + fieldnorm_reader: left_fieldnorms, + postings: left, + }; + + let right_fieldnorms = U32FastFieldReader::from(vec!(15,25,35)); + let right = VecPostings::from(vec!(1, 3, 8)); + + let right_scorer = TermScorer { + idf: 4f32, + fieldnorm_reader: right_fieldnorms, + postings: right, + }; + + let mut boolean_scorer = BooleanScorer::new(vec!(left_scorer, right_scorer), occur_filter); + assert_eq!(boolean_scorer.next(), Some(1u32)); + assert!(abs_diff(boolean_scorer.score(), 0.8707107) < 0.001); + assert_eq!(boolean_scorer.next(), Some(2u32)); + assert!(abs_diff(boolean_scorer.score(), 0.028867513) < 0.001f32); + assert_eq!(boolean_scorer.next(), Some(3u32)); + assert_eq!(boolean_scorer.next(), Some(8u32)); + assert!(abs_diff(boolean_scorer.score(), 0.5163978) < 0.001f32); + assert!(!boolean_scorer.advance()); + } + + + #[test] + pub fn test_term_scorer() { + let left_fieldnorms = U32FastFieldReader::from(vec!(10, 4)); + assert_eq!(left_fieldnorms.get(0), 10); + assert_eq!(left_fieldnorms.get(1), 4); + let left = VecPostings::from(vec!(1)); + let mut left_scorer = TermScorer { + idf: 0.30685282, + fieldnorm_reader: left_fieldnorms, + postings: left, + }; + left_scorer.advance(); + assert!(abs_diff(left_scorer.score(), 0.15342641) < 0.001f32); + } + +} diff --git a/src/query/multi_term_query/multi_term_query.rs b/src/query/multi_term_query/multi_term_query.rs index f7298a65b..c5e410835 100644 --- a/src/query/multi_term_query/multi_term_query.rs +++ b/src/query/multi_term_query/multi_term_query.rs @@ -8,6 +8,7 @@ use core::searcher::Searcher; use query::occur::Occur; use query::occur_filter::OccurFilter; use query::term_query::TermQuery; +use postings::SegmentPostingsOption; /// Query involving one or more terms. @@ -36,7 +37,11 @@ impl MultiTermQuery { .collect(); let occur_filter = OccurFilter::new(&occurs); let weights = term_queries.iter() - .map(|term_query| term_query.specialized_weight(searcher)) + .map(|term_query| { + let mut term_weight = term_query.specialized_weight(searcher); + term_weight.segment_postings_options = SegmentPostingsOption::FreqAndPositions; + term_weight + }) .collect(); MultiTermWeight { weights: weights, diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index 763e72b45..12fe1f6b5 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -4,4 +4,65 @@ mod phrase_scorer; pub use self::phrase_query::PhraseQuery; pub use self::phrase_weight::PhraseWeight; -pub use self::phrase_scorer::PhraseScorer; \ No newline at end of file +pub use self::phrase_scorer::PhraseScorer; + + +#[cfg(test)] +mod tests { + + use super::*; + use query::Query; + use core::Index; + use schema::{Document, Term, SchemaBuilder, TEXT}; + use collector::tests::TestCollector; + + + #[test] + pub fn test_phrase_query() { + + let mut schema_builder = SchemaBuilder::default(); + let text_field = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + { + let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); + { + let mut doc = Document::default(); + doc.add_text(text_field, "a b b d c g c"); + index_writer.add_document(doc).unwrap(); + } + // { + // let mut doc = Document::default(); + // doc.add_text(text_field, "a b a b c"); + // index_writer.add_document(doc).unwrap(); + // } + // { + // let mut doc = Document::default(); + // doc.add_text(text_field, "c a b a d ga a"); + // index_writer.add_document(doc).unwrap(); + // } + // { + // let mut doc = Document::default(); + // doc.add_text(text_field, "a b c"); + // index_writer.add_document(doc).unwrap(); + // } + assert!(index_writer.commit().is_ok()); + } + let mut test_collector = TestCollector::default(); + let build_query = |texts: Vec<&str>| { + let terms: Vec = texts + .iter() + .map(|text| { + Term::from_field_text(text_field, text) + }) + .collect(); + PhraseQuery::from(terms) + }; + let phrase_query = build_query(vec!("a", "b")); + let searcher = index.searcher(); + phrase_query.search(&*searcher, &mut test_collector).expect("search should succeed"); + assert_eq!(test_collector.docs(), vec!(0, 1, 2, 3)); + } + + +} diff --git a/src/query/phrase_query/phrase_query.rs b/src/query/phrase_query/phrase_query.rs index 790e2436d..9058f96b8 100644 --- a/src/query/phrase_query/phrase_query.rs +++ b/src/query/phrase_query/phrase_query.rs @@ -33,8 +33,9 @@ impl Query for PhraseQuery { } -impl PhraseQuery { - pub fn new(terms: Vec) -> PhraseQuery { + +impl From> for PhraseQuery { + fn from(terms: Vec) -> PhraseQuery { assert!(terms.len() > 1); let occur_terms: Vec<(Occur, Term)> = terms.into_iter() .map(|term| (Occur::Must, term)) diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index 4f3285baf..b80422b94 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -7,24 +7,51 @@ use postings::Postings; use DocId; pub struct PhraseScorer<'a> { - pub all_term_scorer: BooleanScorer>> + pub all_term_scorer: BooleanScorer>>, + pub positions_offsets: Vec, } impl<'a> PhraseScorer<'a> { fn phrase_match(&self) -> bool { - let scorers = self.all_term_scorer.scorers(); - for scorer in scorers { - let positions = scorer.postings().positions(); + println!("phrase_match"); + let mut positions_arr: Vec<&[u32]> = self.all_term_scorer + .scorers() + .iter() + .map(|scorer| { + println!("{:?}", scorer.doc()); + scorer.postings().positions() + }) + .collect(); + println!("positions arr {:?}", positions_arr); + let mut cur = 0; + 'outer: loop { + for i in 0..positions_arr.len() { + let positions: &mut &[u32] = &mut positions_arr[i]; + println!("{} {:?} {:?}", i, positions, self.positions_offsets); + if positions.len() == 0 { + return false; + } + let head_position = positions[0] + self.positions_offsets[i]; + println!("cur: {}, head_position {}", cur, head_position); + while head_position < cur { + if positions.len() == 1 { + return false; + } + *positions = &(*positions)[1..]; + } + if head_position != cur { + cur = head_position; + continue 'outer; + } + } + return true; } - true - // self.all_term_scorer.positions(); - // let positions = - } } impl<'a> DocSet for PhraseScorer<'a> { fn advance(&mut self,) -> bool { + println!("docset advance"); while self.all_term_scorer.advance() { if self.phrase_match() { return true; diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs index bd8f271f6..b3d1e002c 100644 --- a/src/query/phrase_query/phrase_weight.rs +++ b/src/query/phrase_query/phrase_weight.rs @@ -20,8 +20,10 @@ impl From for PhraseWeight { impl Weight for PhraseWeight { fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { let all_term_scorer = try!(self.all_term_weight.specialized_scorer(reader)); + let positions_offsets: Vec = (0u32..all_term_scorer.num_subscorers() as u32).collect(); Ok(box PhraseScorer { - all_term_scorer: all_term_scorer + all_term_scorer: all_term_scorer, + positions_offsets: positions_offsets }) } } diff --git a/src/query/term_query/term_query.rs b/src/query/term_query/term_query.rs index 11536ac3f..3ec748c25 100644 --- a/src/query/term_query/term_query.rs +++ b/src/query/term_query/term_query.rs @@ -3,6 +3,7 @@ use Result; use super::term_weight::TermWeight; use query::Query; use query::Weight; +use postings::SegmentPostingsOption; use Searcher; use std::any::Any; @@ -31,7 +32,8 @@ impl TermQuery { TermWeight { num_docs: searcher.num_docs(), doc_freq: searcher.doc_freq(&self.term), - term: self.term.clone() + term: self.term.clone(), + segment_postings_options: SegmentPostingsOption::NoFreq, } } } diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs index 9637d1a91..9d7bac3ee 100644 --- a/src/query/term_query/term_weight.rs +++ b/src/query/term_query/term_weight.rs @@ -11,7 +11,8 @@ use Result; pub struct TermWeight { pub num_docs: u32, pub doc_freq: u32, - pub term: Term, + pub term: Term, + pub segment_postings_options: SegmentPostingsOption, } @@ -35,7 +36,7 @@ impl TermWeight { let fieldnorm_reader = try!(reader.get_fieldnorms_reader(field)); Ok( reader - .read_postings(&self.term, SegmentPostingsOption::Freq) + .read_postings(&self.term, self.segment_postings_options) .map(|segment_postings| TermScorer { idf: self.idf(), From 627e4f1f606802cdca95660a640517e76f346be9 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Thu, 3 Nov 2016 16:56:09 +0900 Subject: [PATCH 15/19] issue/50 Test broken. PhraseQuery uses Intersection DocSet --- src/postings/docset.rs | 8 +++- src/postings/freq_handler.rs | 2 - src/postings/intersection.rs | 58 ++++++++++++++--------- src/postings/mod.rs | 2 +- src/postings/vec_postings.rs | 39 +-------------- src/query/boolean_query/boolean_scorer.rs | 5 -- src/query/phrase_query/mod.rs | 49 +++++++++---------- src/query/phrase_query/phrase_query.rs | 16 ++----- src/query/phrase_query/phrase_scorer.rs | 29 ++++++------ src/query/phrase_query/phrase_weight.rs | 26 ++++++---- 10 files changed, 108 insertions(+), 126 deletions(-) diff --git a/src/postings/docset.rs b/src/postings/docset.rs index faf839325..9c84c2dd8 100644 --- a/src/postings/docset.rs +++ b/src/postings/docset.rs @@ -29,8 +29,14 @@ pub trait DocSet { /// /// SkipResult expresses whether the `target value` was reached, overstepped, /// or if the `DocSet` was entirely consumed without finding any value - /// greater or equal to the `target`. + /// greater or equal to the `target`. + /// + /// WARNING: Calling skip always advances the docset. + /// More specifically, if the docset is already positionned on the target + /// skipping will advance to the next position and return SkipResult::Overstep. + /// fn skip_next(&mut self, target: DocId) -> SkipResult { + self.advance(); loop { match self.doc().cmp(&target) { Ordering::Less => { diff --git a/src/postings/freq_handler.rs b/src/postings/freq_handler.rs index 660ac4977..cb4d6d92e 100644 --- a/src/postings/freq_handler.rs +++ b/src/postings/freq_handler.rs @@ -95,8 +95,6 @@ impl FreqHandler { pub fn positions(&self, idx: usize) -> &[u32] { let start = self.positions_offsets[idx]; let stop = self.positions_offsets[idx + 1]; - println!("{} -> {}", start, stop); - println!("{} {:?}", idx, &self.positions_offsets[..10]); &self.positions[start..stop] } diff --git a/src/postings/intersection.rs b/src/postings/intersection.rs index 2f49910ed..6946a652e 100644 --- a/src/postings/intersection.rs +++ b/src/postings/intersection.rs @@ -1,6 +1,5 @@ use postings::DocSet; use postings::SkipResult; -use std::cmp::Ordering; use DocId; // TODO Find a way to specialize `IntersectionDocSet` @@ -23,38 +22,53 @@ impl From> for IntersectionDocSet { } } +impl IntersectionDocSet { + pub fn docsets(&self) -> &[TDocSet] { + &self.docsets[..] + } +} + impl DocSet for IntersectionDocSet { - + fn advance(&mut self,) -> bool { if self.finished { return false; } - - 'outter: loop { - let doc_candidate = { - let mut first_docset = &mut self.docsets[0]; - if !first_docset.advance() { + let num_docsets = self.docsets.len(); + let mut count_matching = 1; + let mut doc_candidate = { + let mut first_docset = &mut self.docsets[0]; + if !first_docset.advance() { + self.finished = true; + return false; + } + first_docset.doc() + }; + let mut ord = 1; + loop { + let mut doc_set = &mut self.docsets[ord]; + match doc_set.skip_next(doc_candidate) { + SkipResult::Reached => { + count_matching += 1; + if count_matching == num_docsets { + self.doc = doc_candidate; + return true; + } + } + SkipResult::End => { self.finished = true; return false; } - first_docset.doc() - }; - for docset_ord in 1..self.docsets.len() { - let docset: &mut TDocSet = &mut self.docsets[docset_ord]; - match docset.skip_next(doc_candidate) { - SkipResult::End => { - self.finished = true; - return false; - } - SkipResult::OverStep => { - continue 'outter; - }, - SkipResult::Reached => {} + SkipResult::OverStep => { + count_matching = 1; + doc_candidate = doc_set.doc(); } } - self.doc = doc_candidate; - return true; + ord += 1; + if ord == num_docsets { + ord = 0; + } } } diff --git a/src/postings/mod.rs b/src/postings/mod.rs index 7992f629e..a0aa35741 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -49,7 +49,7 @@ mod tests { use core::Index; use std::iter; use datastruct::stacker::Heap; - use query::{Query, TermQuery}; + use query::TermQuery; #[test] diff --git a/src/postings/vec_postings.rs b/src/postings/vec_postings.rs index 02bcf699d..8704915e0 100644 --- a/src/postings/vec_postings.rs +++ b/src/postings/vec_postings.rs @@ -1,9 +1,8 @@ #![allow(dead_code)] use DocId; -use postings::{Postings, DocSet, SkipResult, HasLen}; +use postings::{Postings, DocSet, HasLen}; use std::num::Wrapping; -use std::cmp::Ordering; const EMPTY_ARRAY: [u32; 0] = []; @@ -35,32 +34,6 @@ impl DocSet for VecPostings { fn doc(&self,) -> DocId { self.doc_ids[self.cursor.0] } - - fn skip_next(&mut self, target: DocId) -> SkipResult { - let next_id: usize = (self.cursor + Wrapping( - if self.cursor.0 == usize::max_value() { - 1 - } - else { - 0 - } - )).0; - for i in next_id .. self.doc_ids.len() { - let doc: DocId = self.doc_ids[i]; - match doc.cmp(&target) { - Ordering::Equal => { - self.cursor = Wrapping(i); - return SkipResult::Reached; - } - Ordering::Greater => { - self.cursor = Wrapping(i); - return SkipResult::OverStep; - } - Ordering::Less => {} - } - } - SkipResult::End - } } impl HasLen for VecPostings { @@ -102,14 +75,6 @@ pub mod tests { assert_eq!(postings.doc(), 300u32); assert_eq!(postings.skip_next(6000u32), SkipResult::End); } - - #[test] - pub fn test_vec_postings_skip_without_advance() { - let doc_ids: Vec = (0u32..1024u32).map(|e| e*3).collect(); - let mut postings = VecPostings::from(doc_ids); - assert_eq!(postings.skip_next(300u32), SkipResult::Reached); - assert_eq!(postings.doc(), 300u32); - assert_eq!(postings.skip_next(6000u32), SkipResult::End); - } + } diff --git a/src/query/boolean_query/boolean_scorer.rs b/src/query/boolean_query/boolean_scorer.rs index 4be2cd7c9..e6c17af69 100644 --- a/src/query/boolean_query/boolean_scorer.rs +++ b/src/query/boolean_query/boolean_scorer.rs @@ -77,11 +77,6 @@ impl BooleanScorer { } } - pub fn num_subscorers(&self) -> usize { - self.postings.len() - } - - /// Advances the head of our heap (the segment postings with the lowest doc) /// It will also update the new current `DocId` as well as the term frequency /// associated with the segment postings. diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index 12fe1f6b5..855d198c2 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -9,14 +9,13 @@ pub use self::phrase_scorer::PhraseScorer; #[cfg(test)] mod tests { - + use super::*; use query::Query; use core::Index; use schema::{Document, Term, SchemaBuilder, TEXT}; use collector::tests::TestCollector; - #[test] pub fn test_phrase_query() { @@ -26,43 +25,45 @@ mod tests { let index = Index::create_in_ram(schema); { let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); - { + { // 0 + let mut doc = Document::default(); + doc.add_text(text_field, "b b b d c g c"); + index_writer.add_document(doc).unwrap(); + } + { // 1 let mut doc = Document::default(); doc.add_text(text_field, "a b b d c g c"); index_writer.add_document(doc).unwrap(); } - // { - // let mut doc = Document::default(); - // doc.add_text(text_field, "a b a b c"); - // index_writer.add_document(doc).unwrap(); - // } - // { - // let mut doc = Document::default(); - // doc.add_text(text_field, "c a b a d ga a"); - // index_writer.add_document(doc).unwrap(); - // } - // { - // let mut doc = Document::default(); - // doc.add_text(text_field, "a b c"); - // index_writer.add_document(doc).unwrap(); - // } + { // 2 + let mut doc = Document::default(); + doc.add_text(text_field, "a b a b c"); + index_writer.add_document(doc).unwrap(); + } + { // 3 + let mut doc = Document::default(); + doc.add_text(text_field, "c a b a d ga a"); + index_writer.add_document(doc).unwrap(); + } + { // 4 + let mut doc = Document::default(); + doc.add_text(text_field, "a b c"); + index_writer.add_document(doc).unwrap(); + } assert!(index_writer.commit().is_ok()); } let mut test_collector = TestCollector::default(); let build_query = |texts: Vec<&str>| { let terms: Vec = texts .iter() - .map(|text| { - Term::from_field_text(text_field, text) - }) + .map(|text| Term::from_field_text(text_field, text)) .collect(); PhraseQuery::from(terms) }; - let phrase_query = build_query(vec!("a", "b")); + let phrase_query = build_query(vec!("a", "b", "c")); let searcher = index.searcher(); phrase_query.search(&*searcher, &mut test_collector).expect("search should succeed"); - assert_eq!(test_collector.docs(), vec!(0, 1, 2, 3)); + assert_eq!(test_collector.docs(), vec!(1, 2, 4)); } - } diff --git a/src/query/phrase_query/phrase_query.rs b/src/query/phrase_query/phrase_query.rs index 9058f96b8..816fd4dbc 100644 --- a/src/query/phrase_query/phrase_query.rs +++ b/src/query/phrase_query/phrase_query.rs @@ -1,9 +1,7 @@ use schema::Term; use query::Query; use core::searcher::Searcher; -use query::Occur; use super::PhraseWeight; -use query::MultiTermQuery; use std::any::Any; use query::Weight; use Result; @@ -11,7 +9,7 @@ use Result; #[derive(Debug)] pub struct PhraseQuery { - all_terms_query: MultiTermQuery, + phrase_terms: Vec, } impl Query for PhraseQuery { @@ -27,21 +25,17 @@ impl Query for PhraseQuery { /// /// See [Weight](./trait.Weight.html). fn weight(&self, searcher: &Searcher) -> Result> { - let multi_term_weight = self.all_terms_query.specialized_weight(searcher); - Ok(box PhraseWeight::from(multi_term_weight)) + Ok(box PhraseWeight::from(self.phrase_terms.clone())) } } impl From> for PhraseQuery { - fn from(terms: Vec) -> PhraseQuery { - assert!(terms.len() > 1); - let occur_terms: Vec<(Occur, Term)> = terms.into_iter() - .map(|term| (Occur::Must, term)) - .collect(); + fn from(phrase_terms: Vec) -> PhraseQuery { + assert!(phrase_terms.len() > 1); PhraseQuery { - all_terms_query: MultiTermQuery::from(occur_terms), + phrase_terms: phrase_terms, } } } diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index b80422b94..725099f12 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -1,34 +1,34 @@ use query::Scorer; use DocSet; -use query::term_query::TermScorer; -use query::boolean_query::BooleanScorer; use postings::SegmentPostings; use postings::Postings; +use postings::IntersectionDocSet; use DocId; pub struct PhraseScorer<'a> { - pub all_term_scorer: BooleanScorer>>, + pub intersection_docset: IntersectionDocSet>, pub positions_offsets: Vec, } impl<'a> PhraseScorer<'a> { fn phrase_match(&self) -> bool { - println!("phrase_match"); - let mut positions_arr: Vec<&[u32]> = self.all_term_scorer - .scorers() + let mut positions_arr: Vec<&[u32]> = self.intersection_docset + .docsets() .iter() - .map(|scorer| { - println!("{:?}", scorer.doc()); - scorer.postings().positions() + .map(|posting| { + posting.positions() }) .collect(); println!("positions arr {:?}", positions_arr); + + let mut cur = 0; 'outer: loop { for i in 0..positions_arr.len() { + println!("i {}", i); let positions: &mut &[u32] = &mut positions_arr[i]; - println!("{} {:?} {:?}", i, positions, self.positions_offsets); if positions.len() == 0 { + println!("NOPE"); return false; } let head_position = positions[0] + self.positions_offsets[i]; @@ -51,9 +51,10 @@ impl<'a> PhraseScorer<'a> { impl<'a> DocSet for PhraseScorer<'a> { fn advance(&mut self,) -> bool { - println!("docset advance"); - while self.all_term_scorer.advance() { + while self.intersection_docset.advance() { + println!("doc {}", self.intersection_docset.doc()); if self.phrase_match() { + println!("return {}", self.intersection_docset.doc()); return true; } } @@ -61,14 +62,14 @@ impl<'a> DocSet for PhraseScorer<'a> { } fn doc(&self,) -> DocId { - self.all_term_scorer.doc() + self.intersection_docset.doc() } } impl<'a> Scorer for PhraseScorer<'a> { fn score(&self,) -> f32 { - self.all_term_scorer.score() + 1f32 } } diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs index b3d1e002c..234e22073 100644 --- a/src/query/phrase_query/phrase_weight.rs +++ b/src/query/phrase_query/phrase_weight.rs @@ -1,29 +1,37 @@ use query::Weight; use query::Scorer; +use schema::Term; +use postings::SegmentPostingsOption; use core::SegmentReader; use super::PhraseScorer; -use query::MultiTermWeight; +use postings::IntersectionDocSet; use Result; pub struct PhraseWeight { - all_term_weight: MultiTermWeight, + phrase_terms: Vec, } -impl From for PhraseWeight { - fn from(all_term_weight: MultiTermWeight) -> PhraseWeight { +impl From> for PhraseWeight { + fn from(phrase_terms: Vec) -> PhraseWeight { PhraseWeight { - all_term_weight: all_term_weight + phrase_terms: phrase_terms } } } impl Weight for PhraseWeight { fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { - let all_term_scorer = try!(self.all_term_weight.specialized_scorer(reader)); - let positions_offsets: Vec = (0u32..all_term_scorer.num_subscorers() as u32).collect(); + let mut term_postings_list = Vec::new(); + for term in &self.phrase_terms { + let term_postings_option = reader.read_postings(term, SegmentPostingsOption::FreqAndPositions); + if let Some(term_postings) = term_postings_option { + term_postings_list.push(term_postings); + } + } + let positions_offsets: Vec = (0u32..self.phrase_terms.len() as u32).collect(); Ok(box PhraseScorer { - all_term_scorer: all_term_scorer, - positions_offsets: positions_offsets + intersection_docset: IntersectionDocSet::from(term_postings_list), + positions_offsets: positions_offsets, }) } } From 9d3c9999cb6be5179cf24ddc0e067bc3d8bb1929 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Thu, 3 Nov 2016 23:00:02 +0900 Subject: [PATCH 16/19] issue/50 PhraseQuery working --- src/common/mod.rs | 1 - src/postings/docset.rs | 1 - src/postings/mod.rs | 1 - src/postings/postings.rs | 1 + src/query/mod.rs | 1 + src/query/phrase_query/mod.rs | 19 ++++++---- src/query/phrase_query/phrase_scorer.rs | 48 ++++++++++++++----------- src/query/phrase_query/phrase_weight.rs | 6 ++-- src/query/scorer.rs | 27 ++++++++++++-- 9 files changed, 69 insertions(+), 36 deletions(-) diff --git a/src/common/mod.rs b/src/common/mod.rs index d14e17617..bf38fee07 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -9,7 +9,6 @@ pub use self::timer::OpenTimer; pub use self::vint::VInt; use std::io; - pub fn make_io_err(msg: String) -> io::Error { io::Error::new(io::ErrorKind::Other, msg) } diff --git a/src/postings/docset.rs b/src/postings/docset.rs index 9c84c2dd8..d07c9228b 100644 --- a/src/postings/docset.rs +++ b/src/postings/docset.rs @@ -66,7 +66,6 @@ pub trait DocSet { } } - impl DocSet for Box { fn advance(&mut self,) -> bool { diff --git a/src/postings/mod.rs b/src/postings/mod.rs index a0aa35741..e6c07837a 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -29,7 +29,6 @@ pub use self::postings::Postings; #[cfg(test)] pub use self::vec_postings::VecPostings; - pub use self::chained_postings::ChainedPostings; pub use self::segment_postings::SegmentPostings; pub use self::intersection::IntersectionDocSet; diff --git a/src/postings/postings.rs b/src/postings/postings.rs index ff8038750..dbd83a997 100644 --- a/src/postings/postings.rs +++ b/src/postings/postings.rs @@ -48,3 +48,4 @@ impl<'a, TPostings: Postings> Postings for &'a mut TPostings { } + diff --git a/src/query/mod.rs b/src/query/mod.rs index 870c871cd..c494b8c68 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -25,6 +25,7 @@ pub use self::phrase_query::PhraseQuery; pub use self::multi_term_query::MultiTermQuery; pub use self::multi_term_query::MultiTermWeight; pub use self::scorer::Scorer; +pub use self::scorer::EmptyScorer; pub use self::query_parser::QueryParser; pub use self::explanation::Explanation; pub use self::query_parser::ParsingError; diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index 855d198c2..d0bb61cfa 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -52,18 +52,23 @@ mod tests { } assert!(index_writer.commit().is_ok()); } - let mut test_collector = TestCollector::default(); - let build_query = |texts: Vec<&str>| { + + let searcher = index.searcher(); + let test_query = |texts: Vec<&str>| { + let mut test_collector = TestCollector::default(); let terms: Vec = texts .iter() .map(|text| Term::from_field_text(text_field, text)) .collect(); - PhraseQuery::from(terms) + let phrase_query = PhraseQuery::from(terms); + phrase_query.search(&*searcher, &mut test_collector).expect("search should succeed"); + test_collector.docs() }; - let phrase_query = build_query(vec!("a", "b", "c")); - let searcher = index.searcher(); - phrase_query.search(&*searcher, &mut test_collector).expect("search should succeed"); - assert_eq!(test_collector.docs(), vec!(1, 2, 4)); + assert_eq!(test_query(vec!("a", "b", "c")), vec!(2, 4)); + assert_eq!(test_query(vec!("a", "b")), vec!(1, 2, 3, 4)); + assert_eq!(test_query(vec!("b", "b")), vec!(0, 1)); + assert_eq!(test_query(vec!("g", "ewrwer")), vec!()); + assert_eq!(test_query(vec!("g", "a")), vec!()); } } diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index 725099f12..d2a6a645f 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -7,9 +7,9 @@ use DocId; pub struct PhraseScorer<'a> { pub intersection_docset: IntersectionDocSet>, - pub positions_offsets: Vec, } + impl<'a> PhraseScorer<'a> { fn phrase_match(&self) -> bool { let mut positions_arr: Vec<&[u32]> = self.intersection_docset @@ -19,32 +19,40 @@ impl<'a> PhraseScorer<'a> { posting.positions() }) .collect(); - println!("positions arr {:?}", positions_arr); + let num_postings = positions_arr.len() as u32; + + let mut ord = 1u32; + let mut pos_candidate = positions_arr[0][0]; + positions_arr[0] = &(positions_arr[0])[1..]; + let mut count_matching = 1; - let mut cur = 0; 'outer: loop { - for i in 0..positions_arr.len() { - println!("i {}", i); - let positions: &mut &[u32] = &mut positions_arr[i]; - if positions.len() == 0 { - println!("NOPE"); - return false; + let target = pos_candidate + ord; + let positions = positions_arr[ord as usize]; + for i in 0..positions.len() { + let pos_i = positions[i]; + if pos_i < target { + continue; } - let head_position = positions[0] + self.positions_offsets[i]; - println!("cur: {}, head_position {}", cur, head_position); - while head_position < cur { - if positions.len() == 1 { - return false; + if pos_i == target { + count_matching += 1; + if count_matching == num_postings { + return true; } - *positions = &(*positions)[1..]; } - if head_position != cur { - cur = head_position; - continue 'outer; + else if pos_i > target { + count_matching = 1; + pos_candidate = positions[i] - ord; + positions_arr[ord as usize] = &(positions_arr[ord as usize])[(i+1)..]; } + ord += 1; + if ord == num_postings { + ord = 0; + } + continue 'outer; } - return true; + return false; } } } @@ -52,9 +60,7 @@ impl<'a> PhraseScorer<'a> { impl<'a> DocSet for PhraseScorer<'a> { fn advance(&mut self,) -> bool { while self.intersection_docset.advance() { - println!("doc {}", self.intersection_docset.doc()); if self.phrase_match() { - println!("return {}", self.intersection_docset.doc()); return true; } } diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs index 234e22073..d2a384183 100644 --- a/src/query/phrase_query/phrase_weight.rs +++ b/src/query/phrase_query/phrase_weight.rs @@ -5,6 +5,7 @@ use postings::SegmentPostingsOption; use core::SegmentReader; use super::PhraseScorer; use postings::IntersectionDocSet; +use query::EmptyScorer; use Result; pub struct PhraseWeight { @@ -27,11 +28,12 @@ impl Weight for PhraseWeight { if let Some(term_postings) = term_postings_option { term_postings_list.push(term_postings); } + else { + return Ok(box EmptyScorer); + } } - let positions_offsets: Vec = (0u32..self.phrase_terms.len() as u32).collect(); Ok(box PhraseScorer { intersection_docset: IntersectionDocSet::from(term_postings_list), - positions_offsets: positions_offsets, }) } } diff --git a/src/query/scorer.rs b/src/query/scorer.rs index 75137057d..5f21a2ab3 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -1,4 +1,6 @@ use DocSet; +use DocId; +use Score; use collector::Collector; use std::ops::{Deref, DerefMut}; @@ -10,7 +12,7 @@ pub trait Scorer: DocSet { /// Returns the score. /// /// This method will perform a bit of computation and is not cached. - fn score(&self,) -> f32; + fn score(&self,) -> Score; /// Consumes the complete `DocSet` and /// push the scored documents to the collector. @@ -23,7 +25,7 @@ pub trait Scorer: DocSet { impl<'a> Scorer for Box { - fn score(&self,) -> f32 { + fn score(&self,) -> Score { self.deref().score() } @@ -33,4 +35,23 @@ impl<'a> Scorer for Box { collector.collect(scorer.doc(), scorer.score()); } } -} \ No newline at end of file +} + + +pub struct EmptyScorer; + +impl DocSet for EmptyScorer { + fn advance(&mut self,) -> bool { + false + } + + fn doc(&self,) -> DocId { + DocId::max_value() + } +} + +impl Scorer for EmptyScorer { + fn score(&self,) -> Score { + 0f32 + } +} From f2df0bf0e9a13f372ee84e88d42a43b488257444 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Fri, 4 Nov 2016 00:11:46 +0900 Subject: [PATCH 17/19] issue/50 Small formatting change. --- src/postings/docset.rs | 47 ++++---- src/postings/freq_handler.rs | 36 +++--- src/postings/intersection.rs | 14 ++- src/postings/offset_postings.rs | 22 ++-- src/postings/postings.rs | 19 +--- src/postings/postings_writer.rs | 39 +++---- src/postings/recorder.rs | 90 +++++++-------- src/postings/segment_postings.rs | 41 ++++--- src/postings/segment_postings_option.rs | 4 +- src/postings/serializer.rs | 104 +++++++++--------- src/postings/vec_postings.rs | 29 +++-- src/query/boolean_query/boolean_clause.rs | 4 +- src/query/boolean_query/boolean_query.rs | 17 +-- src/query/boolean_query/boolean_scorer.rs | 30 ++--- src/query/boolean_query/boolean_weight.rs | 16 +-- .../multi_term_query/multi_term_query.rs | 30 ++--- .../multi_term_query/multi_term_weight.rs | 26 +++-- src/query/phrase_query/phrase_query.rs | 15 ++- src/query/scorer.rs | 4 +- 19 files changed, 290 insertions(+), 297 deletions(-) diff --git a/src/postings/docset.rs b/src/postings/docset.rs index d07c9228b..24ae1ffe6 100644 --- a/src/postings/docset.rs +++ b/src/postings/docset.rs @@ -17,23 +17,23 @@ pub enum SkipResult { } -/// Represents an iterable set of sorted doc ids. +/// Represents an iterable set of sorted doc ids. pub trait DocSet { /// Goes to the next element. /// `.advance(...)` needs to be called a first time to point to the correct /// element. - fn advance(&mut self,) -> bool; - + fn advance(&mut self) -> bool; + /// After skipping, position the iterator in such a way that `.doc()` /// will return a value greater than or equal to target. - /// + /// /// SkipResult expresses whether the `target value` was reached, overstepped, /// or if the `DocSet` was entirely consumed without finding any value /// greater or equal to the `target`. /// /// WARNING: Calling skip always advances the docset. /// More specifically, if the docset is already positionned on the target - /// skipping will advance to the next position and return SkipResult::Overstep. + /// skipping will advance to the next position and return SkipResult::Overstep. /// fn skip_next(&mut self, target: DocId) -> SkipResult { self.advance(); @@ -43,32 +43,30 @@ pub trait DocSet { if !self.advance() { return SkipResult::End; } - }, - Ordering::Equal => { return SkipResult::Reached }, - Ordering::Greater => { return SkipResult::OverStep }, + } + Ordering::Equal => return SkipResult::Reached, + Ordering::Greater => return SkipResult::OverStep, } } } - + /// Returns the current document - fn doc(&self,) -> DocId; - + fn doc(&self) -> DocId; + /// Advances the cursor to the next document - /// None is returned if the iterator has `DocSet` - /// has already been entirely consumed. - fn next(&mut self,) -> Option { + /// None is returned if the iterator has `DocSet` + /// has already been entirely consumed. + fn next(&mut self) -> Option { if self.advance() { Some(self.doc()) - } - else { + } else { None } - } + } } impl DocSet for Box { - - fn advance(&mut self,) -> bool { + fn advance(&mut self) -> bool { let unboxed: &mut TDocSet = self.borrow_mut(); unboxed.advance() } @@ -78,28 +76,25 @@ impl DocSet for Box { unboxed.skip_next(target) } - fn doc(&self,) -> DocId { + fn doc(&self) -> DocId { let unboxed: &TDocSet = self.borrow(); unboxed.doc() } } impl<'a, TDocSet: DocSet> DocSet for &'a mut TDocSet { - - fn advance(&mut self,) -> bool { + fn advance(&mut self) -> bool { let unref: &mut TDocSet = *self; unref.advance() } - + fn skip_next(&mut self, target: DocId) -> SkipResult { let unref: &mut TDocSet = *self; unref.skip_next(target) } - fn doc(&self,) -> DocId { + fn doc(&self) -> DocId { let unref: &TDocSet = *self; unref.doc() } } - - diff --git a/src/postings/freq_handler.rs b/src/postings/freq_handler.rs index cb4d6d92e..70808798a 100644 --- a/src/postings/freq_handler.rs +++ b/src/postings/freq_handler.rs @@ -17,7 +17,7 @@ pub struct FreqHandler { fn read_positions(data: &[u8]) -> Vec { - let mut composite_reader = CompositeDecoder::new(); + let mut composite_reader = CompositeDecoder::new(); let mut readable: &[u8] = data; let uncompressed_len = VInt::deserialize(&mut readable).unwrap().0 as usize; composite_reader.uncompress_unsorted(readable, uncompressed_len); @@ -27,17 +27,16 @@ fn read_positions(data: &[u8]) -> Vec { impl FreqHandler { - /// Returns a `FreqHandler` that just decodes `DocId`s. pub fn new_without_freq() -> FreqHandler { FreqHandler { freq_decoder: SIMDBlockDecoder::with_val(1u32), - positions: Vec::new(), + positions: Vec::new(), option: SegmentPostingsOption::NoFreq, positions_offsets: [0; NUM_DOCS_PER_BLOCK + 1], } } - + /// Returns a `FreqHandler` that decodes `DocId`s and term frequencies. pub fn new_with_freq() -> FreqHandler { FreqHandler { @@ -54,15 +53,15 @@ impl FreqHandler { let positions = read_positions(position_data); FreqHandler { freq_decoder: SIMDBlockDecoder::new(), - positions: positions, + positions: positions, option: SegmentPostingsOption::FreqAndPositions, positions_offsets: [0; NUM_DOCS_PER_BLOCK + 1], } } - - fn fill_positions_offset(&mut self,) { + + fn fill_positions_offset(&mut self) { let mut cur_position: usize = self.positions_offsets[NUM_DOCS_PER_BLOCK]; - let mut i: usize = 0; + let mut i: usize = 0; self.positions_offsets[i] = cur_position; let mut last_cur_position = cur_position; for &doc_freq in self.freq_decoder.output_array() { @@ -78,16 +77,16 @@ impl FreqHandler { last_cur_position = cur_position; } } - - + + /// Accessor to term frequency /// /// idx is the offset of the current doc in the block. /// It takes value between 0 and 128. - pub fn freq(&self, idx: usize)-> u32 { + pub fn freq(&self, idx: usize) -> u32 { self.freq_decoder.output(idx) } - + /// Accessor to the positions /// /// idx is the offset of the current doc in the block. @@ -97,16 +96,12 @@ impl FreqHandler { let stop = self.positions_offsets[idx + 1]; &self.positions[start..stop] } - + /// Decompresses a complete frequency block pub fn read_freq_block<'a>(&mut self, data: &'a [u8]) -> &'a [u8] { match self.option { - SegmentPostingsOption::NoFreq => { - data - } - SegmentPostingsOption::Freq => { - self.freq_decoder.uncompress_block_unsorted(data) - } + SegmentPostingsOption::NoFreq => data, + SegmentPostingsOption::Freq => self.freq_decoder.uncompress_block_unsorted(data), SegmentPostingsOption::FreqAndPositions => { let remaining: &'a [u8] = self.freq_decoder.uncompress_block_unsorted(data); self.fill_positions_offset(); @@ -114,7 +109,7 @@ impl FreqHandler { } } } - + /// Decompresses an incomplete frequency block pub fn read_freq_vint(&mut self, data: &[u8], num_els: usize) { match self.option { @@ -128,5 +123,4 @@ impl FreqHandler { } } } - } \ No newline at end of file diff --git a/src/postings/intersection.rs b/src/postings/intersection.rs index 6946a652e..75699065c 100644 --- a/src/postings/intersection.rs +++ b/src/postings/intersection.rs @@ -7,7 +7,7 @@ use DocId; /// Creates a `DocSet` that iterator through the intersection of two `DocSet`s. pub struct IntersectionDocSet { docsets: Vec, - finished: bool, + finished: bool, doc: DocId, } @@ -18,11 +18,14 @@ impl From> for IntersectionDocSet { docsets: docsets, finished: false, doc: DocId::max_value(), - } + } } } impl IntersectionDocSet { + /// Returns an array to the underlying `DocSet`s of the intersection. + /// These `DocSet` are in the same position as the `IntersectionDocSet`, + /// so that user can access their `docfreq` and `positions`. pub fn docsets(&self) -> &[TDocSet] { &self.docsets[..] } @@ -30,8 +33,7 @@ impl IntersectionDocSet { impl DocSet for IntersectionDocSet { - - fn advance(&mut self,) -> bool { + fn advance(&mut self) -> bool { if self.finished { return false; } @@ -71,8 +73,8 @@ impl DocSet for IntersectionDocSet { } } } - - fn doc(&self,) -> DocId { + + fn doc(&self) -> DocId { self.doc } } diff --git a/src/postings/offset_postings.rs b/src/postings/offset_postings.rs index fe7ea453d..1410ef922 100644 --- a/src/postings/offset_postings.rs +++ b/src/postings/offset_postings.rs @@ -15,7 +15,6 @@ pub struct OffsetPostings<'a> { } impl<'a> OffsetPostings<'a> { - /// Constructor pub fn new(underlying: SegmentPostings<'a>, offset: DocId) -> OffsetPostings { OffsetPostings { @@ -26,38 +25,35 @@ impl<'a> OffsetPostings<'a> { } impl<'a> DocSet for OffsetPostings<'a> { - fn advance(&mut self,) -> bool { + fn advance(&mut self) -> bool { self.underlying.advance() } - - fn doc(&self,) -> DocId { + + fn doc(&self) -> DocId { self.underlying.doc() + self.offset } - + fn skip_next(&mut self, target: DocId) -> SkipResult { if target >= self.offset { SkipResult::OverStep - } - else { - self.underlying.skip_next(target - self.offset) + } else { + self.underlying.skip_next(target - self.offset) } } } impl<'a> HasLen for OffsetPostings<'a> { - fn len(&self,) -> usize { + fn len(&self) -> usize { self.underlying.len() } } impl<'a> Postings for OffsetPostings<'a> { - - fn term_freq(&self,) -> u32 { + fn term_freq(&self) -> u32 { self.underlying.term_freq() } - + fn positions(&self) -> &[u32] { self.underlying.positions() } - } \ No newline at end of file diff --git a/src/postings/postings.rs b/src/postings/postings.rs index dbd83a997..52f16198a 100644 --- a/src/postings/postings.rs +++ b/src/postings/postings.rs @@ -7,45 +7,38 @@ use postings::docset::DocSet; /// containing the term. Optionally, for each document, /// it may also give access to the term frequency /// as well as the list of term positions. -/// +/// /// Its main implementation is `SegmentPostings`, /// but other implementations mocking `SegmentPostings` exist, /// for merging segments or for testing. pub trait Postings: DocSet { /// Returns the term frequency - fn term_freq(&self,) -> u32; + fn term_freq(&self) -> u32; /// Returns the list of positions of the term, expressed as a list of /// token ordinals. fn positions(&self) -> &[u32]; } impl Postings for Box { - - fn term_freq(&self,) -> u32 { + fn term_freq(&self) -> u32 { let unboxed: &TPostings = self.borrow(); unboxed.term_freq() } - + fn positions(&self) -> &[u32] { let unboxed: &TPostings = self.borrow(); unboxed.positions() } - } impl<'a, TPostings: Postings> Postings for &'a mut TPostings { - - fn term_freq(&self,) -> u32 { + fn term_freq(&self) -> u32 { let unref: &TPostings = *self; unref.term_freq() } - + fn positions(&self) -> &[u32] { let unref: &TPostings = *self; unref.positions() } - } - - - diff --git a/src/postings/postings_writer.rs b/src/postings/postings_writer.rs index 3b6ddd440..c3d1f997f 100644 --- a/src/postings/postings_writer.rs +++ b/src/postings/postings_writer.rs @@ -9,12 +9,11 @@ use schema::Field; use analyzer::StreamingIterator; use datastruct::stacker::{HashMap, Heap}; -/// The `PostingsWriter` is in charge of receiving documenting +/// The `PostingsWriter` is in charge of receiving documenting /// and building a `Segment` in anonymous memory. /// /// `PostingsWriter` writes in a `Heap`. pub trait PostingsWriter { - /// Record that a document contains a term at a given position. /// /// * doc - the document id @@ -22,17 +21,22 @@ pub trait PostingsWriter { /// * term - the term /// * heap - heap used to store the postings informations as well as the terms /// in the hashmap. - fn suscribe(&mut self, doc: DocId, pos: u32, term: &Term, heap: &Heap); - + fn suscribe(&mut self, doc: DocId, pos: u32, term: &Term, heap: &Heap); + /// Serializes the postings on disk. /// The actual serialization format is handled by the `PostingsSerializer`. fn serialize(&self, serializer: &mut PostingsSerializer, heap: &Heap) -> io::Result<()>; - + /// Closes all of the currently open `Recorder`'s. fn close(&mut self, heap: &Heap); - + /// Tokenize a text and suscribe all of its token. - fn index_text<'a>(&mut self, doc_id: DocId, field: Field, field_values: &[&'a FieldValue], heap: &Heap) -> u32 { + fn index_text<'a>(&mut self, + doc_id: DocId, + field: Field, + field_values: &[&'a FieldValue], + heap: &Heap) + -> u32 { let mut pos = 0u32; let mut num_tokens: u32 = 0u32; let mut term = Term::allocate(field, 100); @@ -65,7 +69,7 @@ fn hashmap_size_in_bits(heap_capacity: u32) -> usize { let num_buckets_usable = heap_capacity / 100; let hash_table_size = num_buckets_usable * 2; let mut pow = 512; - for num_bits in 10 .. 32 { + for num_bits in 10..32 { pow <<= 1; if pow > hash_table_size { return num_bits; @@ -75,31 +79,26 @@ fn hashmap_size_in_bits(heap_capacity: u32) -> usize { } impl<'a, Rec: Recorder + 'static> SpecializedPostingsWriter<'a, Rec> { - /// constructor pub fn new(heap: &'a Heap) -> SpecializedPostingsWriter<'a, Rec> { let capacity = heap.capacity(); let hashmap_size = hashmap_size_in_bits(capacity); - SpecializedPostingsWriter { - term_index: HashMap::new(hashmap_size, heap), - } + SpecializedPostingsWriter { term_index: HashMap::new(hashmap_size, heap) } } - + /// Builds a `SpecializedPostingsWriter` storing its data in a heap. pub fn new_boxed(heap: &'a Heap) -> Box { Box::new(SpecializedPostingsWriter::::new(heap)) - } - + } } impl<'a, Rec: Recorder + 'static> PostingsWriter for SpecializedPostingsWriter<'a, Rec> { - fn close(&mut self, heap: &Heap) { for recorder in self.term_index.values_mut() { recorder.close_doc(heap); } } - + #[inline] fn suscribe(&mut self, doc: DocId, position: u32, term: &Term, heap: &Heap) { let mut recorder = self.term_index.get_or_create(term); @@ -112,9 +111,9 @@ impl<'a, Rec: Recorder + 'static> PostingsWriter for SpecializedPostingsWriter<' } recorder.record_position(position, heap); } - + fn serialize(&self, serializer: &mut PostingsSerializer, heap: &Heap) -> io::Result<()> { - let mut term_offsets: Vec<(&[u8], (u32, &Rec))> = self.term_index + let mut term_offsets: Vec<(&[u8], (u32, &Rec))> = self.term_index .iter() .collect(); term_offsets.sort_by_key(|&(k, _v)| k); @@ -128,8 +127,6 @@ impl<'a, Rec: Recorder + 'static> PostingsWriter for SpecializedPostingsWriter<' } Ok(()) } - - } diff --git a/src/postings/recorder.rs b/src/postings/recorder.rs index 095102a3d..94173720b 100644 --- a/src/postings/recorder.rs +++ b/src/postings/recorder.rs @@ -4,32 +4,36 @@ use postings::PostingsSerializer; use datastruct::stacker::{ExpUnrolledLinkedList, Heap, HeapAllocable}; const EMPTY_ARRAY: [u32; 0] = [0u32; 0]; -const POSITION_END: u32 = 4294967295; +const POSITION_END: u32 = 4294967295; /// Recorder is in charge of recording relevant information about /// the presence of a term in a document. /// -/// Depending on the `TextIndexingOptions` associated to the +/// Depending on the `TextIndexingOptions` associated to the /// field, the recorder may records /// * the document frequency -/// * the document id +/// * the document id /// * the term frequency /// * the term positions pub trait Recorder: HeapAllocable { /// Returns the current document - fn current_doc(&self,) -> u32; + fn current_doc(&self) -> u32; /// Starts recording information about a new document - /// This method shall only be called if the term is within the document. + /// This method shall only be called if the term is within the document. fn new_doc(&mut self, doc: DocId, heap: &Heap); - /// Record the position of a term. For each document, + /// Record the position of a term. For each document, /// this method will be called `term_freq` times. fn record_position(&mut self, position: u32, heap: &Heap); - /// Close the document. It will help record the term frequency. + /// Close the document. It will help record the term frequency. fn close_doc(&mut self, heap: &Heap); /// Returns the number of document that have been seen so far - fn doc_freq(&self,) -> u32; + fn doc_freq(&self) -> u32; /// Pushes the postings information to the serializer. - fn serialize(&self, self_addr: u32, serializer: &mut PostingsSerializer, heap: &Heap) -> io::Result<()>; + fn serialize(&self, + self_addr: u32, + serializer: &mut PostingsSerializer, + heap: &Heap) + -> io::Result<()>; } /// Only records the doc ids @@ -51,11 +55,10 @@ impl HeapAllocable for NothingRecorder { } impl Recorder for NothingRecorder { - - fn current_doc(&self,) -> DocId { + fn current_doc(&self) -> DocId { self.current_doc } - + fn new_doc(&mut self, doc: DocId, heap: &Heap) { self.current_doc = doc; self.stack.push(doc, heap); @@ -66,17 +69,20 @@ impl Recorder for NothingRecorder { fn close_doc(&mut self, _heap: &Heap) {} - fn doc_freq(&self,) -> u32 { + fn doc_freq(&self) -> u32 { self.doc_freq } - - fn serialize(&self, self_addr: u32, serializer: &mut PostingsSerializer, heap: &Heap) -> io::Result<()> { + + fn serialize(&self, + self_addr: u32, + serializer: &mut PostingsSerializer, + heap: &Heap) + -> io::Result<()> { for doc in self.stack.iter(self_addr, heap) { try!(serializer.write_doc(doc, 0u32, &EMPTY_ARRAY)); } Ok(()) } - } /// Recorder encoding document ids, and term frequencies @@ -94,16 +100,13 @@ impl HeapAllocable for TermFrequencyRecorder { stack: ExpUnrolledLinkedList::with_addr(addr), current_doc: u32::max_value(), current_tf: 0u32, - doc_freq: 0u32 - } + doc_freq: 0u32, + } } } impl Recorder for TermFrequencyRecorder { - - - - fn current_doc(&self,) -> DocId { + fn current_doc(&self) -> DocId { self.current_doc } @@ -112,22 +115,26 @@ impl Recorder for TermFrequencyRecorder { self.current_doc = doc; self.stack.push(doc, heap); } - + fn record_position(&mut self, _position: u32, _heap: &Heap) { self.current_tf += 1; } - + fn close_doc(&mut self, heap: &Heap) { debug_assert!(self.current_tf > 0); self.stack.push(self.current_tf, heap); self.current_tf = 0; } - - fn doc_freq(&self,) -> u32 { + + fn doc_freq(&self) -> u32 { self.doc_freq } - - fn serialize(&self, self_addr:u32, serializer: &mut PostingsSerializer, heap: &Heap) -> io::Result<()> { + + fn serialize(&self, + self_addr: u32, + serializer: &mut PostingsSerializer, + heap: &Heap) + -> io::Result<()> { let mut doc_iter = self.stack.iter(self_addr, heap); loop { if let Some(doc) = doc_iter.next() { @@ -140,7 +147,6 @@ impl Recorder for TermFrequencyRecorder { } Ok(()) } - } /// Recorder encoding term frequencies as well as positions. @@ -162,12 +168,10 @@ impl HeapAllocable for TFAndPositionRecorder { } impl Recorder for TFAndPositionRecorder { - - - fn current_doc(&self,) -> DocId { + fn current_doc(&self) -> DocId { self.current_doc } - + fn new_doc(&mut self, doc: DocId, heap: &Heap) { self.doc_freq += 1; self.current_doc = doc; @@ -177,16 +181,20 @@ impl Recorder for TFAndPositionRecorder { fn record_position(&mut self, position: u32, heap: &Heap) { self.stack.push(position, heap); } - + fn close_doc(&mut self, heap: &Heap) { self.stack.push(POSITION_END, heap); } - - fn doc_freq(&self,) -> u32 { + + fn doc_freq(&self) -> u32 { self.doc_freq } - - fn serialize(&self, self_addr: u32, serializer: &mut PostingsSerializer, heap: &Heap) -> io::Result<()> { + + fn serialize(&self, + self_addr: u32, + serializer: &mut PostingsSerializer, + heap: &Heap) + -> io::Result<()> { let mut doc_positions = Vec::with_capacity(100); let mut positions_iter = self.stack.iter(self_addr, heap); while let Some(doc) = positions_iter.next() { @@ -197,8 +205,7 @@ impl Recorder for TFAndPositionRecorder { Some(position) => { if position == POSITION_END { break; - } - else { + } else { doc_positions.push(position - prev_position); prev_position = position; } @@ -212,7 +219,4 @@ impl Recorder for TFAndPositionRecorder { } Ok(()) } - } - - diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs index cac9b86c8..0bb8af8e3 100644 --- a/src/postings/segment_postings.rs +++ b/src/postings/segment_postings.rs @@ -6,9 +6,9 @@ use std::num::Wrapping; const EMPTY_DATA: [u8; 0] = [0u8; 0]; -/// `SegmentPostings` represents the inverted list or postings associated to +/// `SegmentPostings` represents the inverted list or postings associated to /// a term in a `Segment`. -/// +/// /// As we iterate through the `SegmentPostings`, the frequencies are optionally decoded. /// Positions on the other hand, are optionally entirely decoded upfront. pub struct SegmentPostings<'a> { @@ -16,21 +16,21 @@ pub struct SegmentPostings<'a> { doc_offset: u32, block_decoder: SIMDBlockDecoder, freq_handler: FreqHandler, - remaining_data: &'a[u8], + remaining_data: &'a [u8], cur: Wrapping, } impl<'a> SegmentPostings<'a> { - - fn load_next_block(&mut self,) { + fn load_next_block(&mut self) { let num_remaining_docs = self.len - self.cur.0; if num_remaining_docs >= NUM_DOCS_PER_BLOCK { - self.remaining_data = self.block_decoder.uncompress_block_sorted(self.remaining_data, self.doc_offset); + self.remaining_data = self.block_decoder + .uncompress_block_sorted(self.remaining_data, self.doc_offset); self.remaining_data = self.freq_handler.read_freq_block(self.remaining_data); self.doc_offset = self.block_decoder.output(NUM_DOCS_PER_BLOCK - 1); - } - else { - self.remaining_data = self.block_decoder.uncompress_vint_sorted(self.remaining_data, self.doc_offset, num_remaining_docs); + } else { + self.remaining_data = self.block_decoder + .uncompress_vint_sorted(self.remaining_data, self.doc_offset, num_remaining_docs); self.freq_handler.read_freq_vint(self.remaining_data, num_remaining_docs); } } @@ -39,7 +39,7 @@ impl<'a> SegmentPostings<'a> { /// /// * `len` - number of document in the posting lists. /// * `data` - data array. The complete data is not necessarily used. - /// * `freq_handler` - the freq handler is in charge of decoding + /// * `freq_handler` - the freq handler is in charge of decoding /// frequencies and/or positions pub fn from_data(len: u32, data: &'a [u8], freq_handler: FreqHandler) -> SegmentPostings<'a> { SegmentPostings { @@ -51,7 +51,7 @@ impl<'a> SegmentPostings<'a> { cur: Wrapping(usize::max_value()), } } - + /// Returns an empty segment postings object pub fn empty() -> SegmentPostings<'static> { SegmentPostings { @@ -65,11 +65,10 @@ impl<'a> SegmentPostings<'a> { } /// Index within a block is used as an address when - /// interacting with the `FreqHandler` - fn index_within_block(&self,) -> usize { + /// interacting with the `FreqHandler` + fn index_within_block(&self) -> usize { self.cur.0 % NUM_DOCS_PER_BLOCK } - } @@ -77,7 +76,7 @@ impl<'a> DocSet for SegmentPostings<'a> { // goes to the next element. // next needs to be called a first time to point to the correct element. #[inline] - fn advance(&mut self,) -> bool { + fn advance(&mut self) -> bool { self.cur += Wrapping(1); if self.cur.0 >= self.len { return false; @@ -87,27 +86,25 @@ impl<'a> DocSet for SegmentPostings<'a> { } true } - + #[inline] - fn doc(&self,) -> DocId { + fn doc(&self) -> DocId { self.block_decoder.output(self.index_within_block()) } - } impl<'a> HasLen for SegmentPostings<'a> { - fn len(&self,) -> usize { + fn len(&self) -> usize { self.len } } impl<'a> Postings for SegmentPostings<'a> { - fn term_freq(&self,) -> u32 { + fn term_freq(&self) -> u32 { self.freq_handler.freq(self.index_within_block()) } - + fn positions(&self) -> &[u32] { self.freq_handler.positions(self.index_within_block()) } } - diff --git a/src/postings/segment_postings_option.rs b/src/postings/segment_postings_option.rs index 082ea0660..cf2f8b936 100644 --- a/src/postings/segment_postings_option.rs +++ b/src/postings/segment_postings_option.rs @@ -2,7 +2,7 @@ /// Object describing the amount of information required when reading a postings. /// -/// Since decoding information is not free, this makes it possible to +/// Since decoding information is not free, this makes it possible to /// avoid this extra cost when the information is not required. /// For instance, positions are useful when running phrase queries /// but useless in other queries. @@ -14,4 +14,4 @@ pub enum SegmentPostingsOption { Freq, /// DocIds, term frequencies and positions will be decoded. FreqAndPositions, -} \ No newline at end of file +} diff --git a/src/postings/serializer.rs b/src/postings/serializer.rs index 673a9059d..3316d1f5e 100644 --- a/src/postings/serializer.rs +++ b/src/postings/serializer.rs @@ -19,14 +19,14 @@ use common::BinarySerializable; /// `PostingsSerializer` is in charge of serializing -/// postings on disk, in the +/// postings on disk, in the /// * `.idx` (inverted index) /// * `.pos` (positions file) /// * `.term` (term dictionary) -/// -/// `PostingsWriter` are in charge of pushing the data to the +/// +/// `PostingsWriter` are in charge of pushing the data to the /// serializer. -/// +/// /// The serializer expects to receive the following calls /// in this order : /// @@ -45,10 +45,10 @@ use common::BinarySerializable; /// Terms have to be pushed in a lexicographically-sorted order. /// Within a term, document have to be pushed in increasing order. /// -/// A description of the serialization format is -/// [available here](https://fulmicoton.gitbooks.io/tantivy-doc/content/inverted-index.html). +/// A description of the serialization format is +/// [available here](https://fulmicoton.gitbooks.io/tantivy-doc/content/inverted-index.html). pub struct PostingsSerializer { - terms_fst_builder: FstMapBuilder, // TODO find an alternative to work around the "move" + terms_fst_builder: FstMapBuilder, /* TODO find an alternative to work around the "move" */ postings_write: WritePtr, positions_write: WritePtr, written_bytes_postings: usize, @@ -65,14 +65,12 @@ pub struct PostingsSerializer { } impl PostingsSerializer { - - /// Open a new `PostingsSerializer` for the given segment - pub fn new( - terms_write: WritePtr, - postings_write: WritePtr, - positions_write: WritePtr, - schema: Schema - ) -> Result { + /// Open a new `PostingsSerializer` for the given segment + pub fn new(terms_write: WritePtr, + postings_write: WritePtr, + positions_write: WritePtr, + schema: Schema) + -> Result { let terms_fst_builder = try!(FstMapBuilder::new(terms_write)); Ok(PostingsSerializer { terms_fst_builder: terms_fst_builder, @@ -91,41 +89,36 @@ impl PostingsSerializer { term_open: false, }) } - - - /// Open a new `PostingsSerializer` for the given segment + + + /// Open a new `PostingsSerializer` for the given segment pub fn open(segment: &mut Segment) -> Result { let terms_write = try!(segment.open_write(SegmentComponent::TERMS)); let postings_write = try!(segment.open_write(SegmentComponent::POSTINGS)); let positions_write = try!(segment.open_write(SegmentComponent::POSITIONS)); - PostingsSerializer::new( - terms_write, - postings_write, - positions_write, - segment.schema() - ) + PostingsSerializer::new(terms_write, + postings_write, + positions_write, + segment.schema()) } - + fn load_indexing_options(&mut self, field: Field) { let field_entry: &FieldEntry = self.schema.get_field_entry(field); self.text_indexing_options = match *field_entry.field_type() { - FieldType::Str(ref text_options) => { - text_options.get_indexing_options() - } + FieldType::Str(ref text_options) => text_options.get_indexing_options(), FieldType::U32(ref u32_options) => { if u32_options.is_indexed() { TextIndexingOptions::Unindexed - } - else { - TextIndexingOptions::Untokenized + } else { + TextIndexingOptions::Untokenized } } }; } - + /// Starts the postings for a new term. /// * term - the term. It needs to come after the previous term according - /// to the lexicographical order. + /// to the lexicographical order. /// * doc_freq - return the number of document containing the term. pub fn new_term(&mut self, term: &Term, doc_freq: DocId) -> io::Result<()> { if self.term_open { @@ -145,31 +138,34 @@ impl PostingsSerializer { self.terms_fst_builder .insert(term.as_slice(), &term_info) } - + /// Finish the serialization for this term postings. /// /// If the current block is incomplete, it need to be encoded - /// using `VInt` encoding. - pub fn close_term(&mut self,) -> io::Result<()> { + /// using `VInt` encoding. + pub fn close_term(&mut self) -> io::Result<()> { if self.term_open { if !self.doc_ids.is_empty() { // we have doc ids waiting to be written - // this happens when the number of doc ids is + // this happens when the number of doc ids is // not a perfect multiple of our block size. // // In that case, the remaining part is encoded // using variable int encoding. { - let block_encoded = self.block_encoder.compress_vint_sorted(&self.doc_ids, self.last_doc_id_encoded); + let block_encoded = self.block_encoder + .compress_vint_sorted(&self.doc_ids, self.last_doc_id_encoded); self.written_bytes_postings += block_encoded.len(); try!(self.postings_write.write_all(block_encoded)); self.doc_ids.clear(); } - // ... Idem for term frequencies + // ... Idem for term frequencies if self.text_indexing_options.is_termfreq_enabled() { - let block_encoded = self.block_encoder.compress_vint_unsorted(&self.term_freqs[..]); + let block_encoded = self.block_encoder + .compress_vint_unsorted(&self.term_freqs[..]); for num in block_encoded { - self.written_bytes_postings += try!(num.serialize(&mut self.postings_write)); + self.written_bytes_postings += + try!(num.serialize(&mut self.postings_write)); } self.term_freqs.clear(); } @@ -177,8 +173,10 @@ impl PostingsSerializer { // On the other hand, positions are entirely buffered until the // end of the term, at which point they are compressed and written. if self.text_indexing_options.is_position_enabled() { - self.written_bytes_positions += try!(VInt(self.position_deltas.len() as u64).serialize(&mut self.positions_write)); - let positions_encoded: &[u8] = self.positions_encoder.compress_unsorted(&self.position_deltas[..]); + self.written_bytes_positions += try!(VInt(self.position_deltas.len() as u64) + .serialize(&mut self.positions_write)); + let positions_encoded: &[u8] = self.positions_encoder + .compress_unsorted(&self.position_deltas[..]); try!(self.positions_write.write_all(positions_encoded)); self.written_bytes_positions += positions_encoded.len(); self.position_deltas.clear(); @@ -187,8 +185,8 @@ impl PostingsSerializer { } Ok(()) } - - + + /// Serialize the information that a document contains the current term, /// its term frequency, and the position deltas. /// @@ -198,7 +196,11 @@ impl PostingsSerializer { /// /// Term frequencies and positions may be ignored by the serializer depending /// on the configuration of the field in the `Schema`. - pub fn write_doc(&mut self, doc_id: DocId, term_freq: u32, position_deltas: &[u32]) -> io::Result<()> { + pub fn write_doc(&mut self, + doc_id: DocId, + term_freq: u32, + position_deltas: &[u32]) + -> io::Result<()> { self.doc_ids.push(doc_id); if self.text_indexing_options.is_termfreq_enabled() { self.term_freqs.push(term_freq as u32); @@ -209,14 +211,16 @@ impl PostingsSerializer { if self.doc_ids.len() == NUM_DOCS_PER_BLOCK { { // encode the doc ids - let block_encoded: &[u8] = self.block_encoder.compress_block_sorted(&self.doc_ids, self.last_doc_id_encoded); + let block_encoded: &[u8] = self.block_encoder + .compress_block_sorted(&self.doc_ids, self.last_doc_id_encoded); self.last_doc_id_encoded = self.doc_ids[self.doc_ids.len() - 1]; try!(self.postings_write.write_all(block_encoded)); self.written_bytes_postings += block_encoded.len(); } if self.text_indexing_options.is_termfreq_enabled() { // encode the term_freqs - let block_encoded: &[u8] = self.block_encoder.compress_block_unsorted(&self.term_freqs); + let block_encoded: &[u8] = self.block_encoder + .compress_block_unsorted(&self.term_freqs); try!(self.postings_write.write_all(block_encoded)); self.written_bytes_postings += block_encoded.len(); self.term_freqs.clear(); @@ -225,9 +229,9 @@ impl PostingsSerializer { } Ok(()) } - + /// Closes the serializer. - pub fn close(mut self,) -> io::Result<()> { + pub fn close(mut self) -> io::Result<()> { try!(self.close_term()); try!(self.terms_fst_builder.finish()); try!(self.postings_write.flush()); diff --git a/src/postings/vec_postings.rs b/src/postings/vec_postings.rs index 8704915e0..399307cff 100644 --- a/src/postings/vec_postings.rs +++ b/src/postings/vec_postings.rs @@ -4,7 +4,7 @@ use DocId; use postings::{Postings, DocSet, HasLen}; use std::num::Wrapping; -const EMPTY_ARRAY: [u32; 0] = []; +const EMPTY_ARRAY: [u32; 0] = []; /// Simulate a `Postings` objects from a `VecPostings`. /// `VecPostings` only exist for testing purposes. @@ -26,43 +26,43 @@ impl From> for VecPostings { } impl DocSet for VecPostings { - fn advance(&mut self,) -> bool { + fn advance(&mut self) -> bool { self.cursor += Wrapping(1); self.doc_ids.len() > self.cursor.0 } - - fn doc(&self,) -> DocId { + + fn doc(&self) -> DocId { self.doc_ids[self.cursor.0] } } impl HasLen for VecPostings { - fn len(&self,) -> usize { + fn len(&self) -> usize { self.doc_ids.len() } } impl Postings for VecPostings { - fn term_freq(&self,) -> u32 { + fn term_freq(&self) -> u32 { 1u32 } - + fn positions(&self) -> &[u32] { &EMPTY_ARRAY - } + } } #[cfg(test)] pub mod tests { - + use super::*; - use DocId; - use postings::{Postings, SkipResult, DocSet}; - - + use DocId; + use postings::{Postings, SkipResult, DocSet}; + + #[test] pub fn test_vec_postings() { - let doc_ids: Vec = (0u32..1024u32).map(|e| e*3).collect(); + let doc_ids: Vec = (0u32..1024u32).map(|e| e * 3).collect(); let mut postings = VecPostings::from(doc_ids); assert!(postings.advance()); assert_eq!(postings.doc(), 0u32); @@ -77,4 +77,3 @@ pub mod tests { } } - diff --git a/src/query/boolean_query/boolean_clause.rs b/src/query/boolean_query/boolean_clause.rs index 34f49f0b7..e2e2a55b6 100644 --- a/src/query/boolean_query/boolean_clause.rs +++ b/src/query/boolean_query/boolean_clause.rs @@ -12,7 +12,7 @@ impl BooleanClause { pub fn new(query: Box, occur: Occur) -> BooleanClause { BooleanClause { query: query, - occur: occur + occur: occur, } - } + } } \ No newline at end of file diff --git a/src/query/boolean_query/boolean_query.rs b/src/query/boolean_query/boolean_query.rs index abecdb148..195bbf20c 100644 --- a/src/query/boolean_query/boolean_query.rs +++ b/src/query/boolean_query/boolean_query.rs @@ -12,11 +12,11 @@ use query::OccurFilter; /// /// The documents matched by the boolean query are /// those which -/// * match all of the sub queries associated with the +/// * match all of the sub queries associated with the /// `Must` occurence -/// * match none of the sub queries associated with the +/// * match none of the sub queries associated with the /// `MustNot` occurence. -/// * match at least one of the subqueries that is not +/// * match at least one of the subqueries that is not /// a `MustNot` occurence. #[derive(Debug)] pub struct BooleanQuery { @@ -25,14 +25,11 @@ pub struct BooleanQuery { impl From> for BooleanQuery { fn from(clauses: Vec) -> BooleanQuery { - BooleanQuery { - clauses: clauses, - } - } + BooleanQuery { clauses: clauses } + } } impl Query for BooleanQuery { - fn as_any(&self) -> &Any { self } @@ -41,8 +38,7 @@ impl Query for BooleanQuery { let sub_weights = try!(self.clauses .iter() .map(|clause| clause.query.weight(searcher)) - .collect() - ); + .collect()); let occurs: Vec = self.clauses .iter() .map(|clause| clause.occur) @@ -50,5 +46,4 @@ impl Query for BooleanQuery { let filter = OccurFilter::new(&occurs); Ok(box BooleanWeight::new(sub_weights, filter)) } - } \ No newline at end of file diff --git a/src/query/boolean_query/boolean_scorer.rs b/src/query/boolean_query/boolean_scorer.rs index e6c17af69..c24f67760 100644 --- a/src/query/boolean_query/boolean_scorer.rs +++ b/src/query/boolean_query/boolean_scorer.rs @@ -33,7 +33,7 @@ impl Ord for HeapItem { } pub struct BooleanScorer { - postings: Vec, + scorers: Vec, queue: BinaryHeap, doc: DocId, score_combiner: ScoreCombiner, @@ -43,20 +43,20 @@ pub struct BooleanScorer { impl BooleanScorer { pub fn scorers(&self) -> &[TScorer] { - &self.postings + &self.scorers } - pub fn new(postings: Vec, + pub fn new(scorers: Vec, occur_filter: OccurFilter) -> BooleanScorer { - let score_combiner = ScoreCombiner::default_for_num_scorers(postings.len()); - let mut non_empty_postings: Vec = Vec::new(); - for mut posting in postings { + let score_combiner = ScoreCombiner::default_for_num_scorers(scorers.len()); + let mut non_empty_scorers: Vec = Vec::new(); + for mut posting in scorers { let non_empty = posting.advance(); if non_empty { - non_empty_postings.push(posting); + non_empty_scorers.push(posting); } } - let heap_items: Vec = non_empty_postings + let heap_items: Vec = non_empty_scorers .iter() .map(|posting| posting.doc()) .enumerate() @@ -68,7 +68,7 @@ impl BooleanScorer { }) .collect(); BooleanScorer { - postings: non_empty_postings, + scorers: non_empty_scorers, queue: BinaryHeap::from(heap_items), doc: 0u32, score_combiner: score_combiner, @@ -77,7 +77,7 @@ impl BooleanScorer { } } - /// Advances the head of our heap (the segment postings with the lowest doc) + /// Advances the head of our heap (the segment posting with the lowest doc) /// It will also update the new current `DocId` as well as the term frequency /// associated with the segment postings. /// @@ -89,9 +89,9 @@ impl BooleanScorer { fn advance_head(&mut self,) { { let mut mutable_head = self.queue.peek_mut().unwrap(); - let cur_postings = &mut self.postings[mutable_head.ord as usize]; - if cur_postings.advance() { - mutable_head.doc = cur_postings.doc(); + let cur_scorers = &mut self.scorers[mutable_head.ord as usize]; + if cur_scorers.advance() { + mutable_head.doc = cur_scorers.doc(); return; } } @@ -108,7 +108,7 @@ impl DocSet for BooleanScorer { Some(heap_item) => { let ord = heap_item.ord as usize; self.doc = heap_item.doc; - let score = self.postings[ord].score(); + let score = self.scorers[ord].score(); self.score_combiner.update(score); ord_bitset |= 1 << ord; } @@ -120,7 +120,7 @@ impl DocSet for BooleanScorer { while let Some(&HeapItem {doc, ord}) = self.queue.peek() { if doc == self.doc { let ord = ord as usize; - let score = self.postings[ord].score(); + let score = self.scorers[ord].score(); self.score_combiner.update(score); ord_bitset |= 1 << ord; } diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index 930b47348..830f85edf 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -11,8 +11,7 @@ pub struct BooleanWeight { } impl BooleanWeight { - pub fn new(weights: Vec>, - occur_filter: OccurFilter) -> BooleanWeight { + pub fn new(weights: Vec>, occur_filter: OccurFilter) -> BooleanWeight { BooleanWeight { weights: weights, occur_filter: occur_filter, @@ -22,15 +21,12 @@ impl BooleanWeight { impl Weight for BooleanWeight { - fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { - let sub_scorers: Vec> = try!( - self.weights - .iter() - .map(|weight| weight.scorer(reader)) - .collect() - ); - let boolean_scorer = BooleanScorer::new(sub_scorers, self.occur_filter); + let sub_scorers: Vec> = try!(self.weights + .iter() + .map(|weight| weight.scorer(reader)) + .collect()); + let boolean_scorer = BooleanScorer::new(sub_scorers, self.occur_filter); Ok(box boolean_scorer) } } diff --git a/src/query/multi_term_query/multi_term_query.rs b/src/query/multi_term_query/multi_term_query.rs index c5e410835..7b207d64f 100644 --- a/src/query/multi_term_query/multi_term_query.rs +++ b/src/query/multi_term_query/multi_term_query.rs @@ -12,20 +12,21 @@ use postings::SegmentPostingsOption; /// Query involving one or more terms. - #[derive(Eq, Clone, PartialEq, Debug)] -pub struct MultiTermQuery { - // TODO need a better Debug - occur_terms: Vec<(Occur, Term)> +pub struct MultiTermQuery { + // TODO need a better Debug + occur_terms: Vec<(Occur, Term)>, } impl MultiTermQuery { - /// Accessor for the number of terms - pub fn num_terms(&self,) -> usize { + pub fn num_terms(&self) -> usize { self.occur_terms.len() } + /// Same as `weight()`, except that rather than a boxed trait, + /// `specialized_weight` returns a specific type of the weight, allowing for + /// compile-time optimization. pub fn specialized_weight(&self, searcher: &Searcher) -> MultiTermWeight { let term_queries: Vec = self.occur_terms .iter() @@ -33,7 +34,7 @@ impl MultiTermQuery { .collect(); let occurs: Vec = self.occur_terms .iter() - .map(|&(occur, _) | occur.clone()) + .map(|&(occur, _)| occur.clone()) .collect(); let occur_filter = OccurFilter::new(&occurs); let weights = term_queries.iter() @@ -43,21 +44,17 @@ impl MultiTermQuery { term_weight }) .collect(); - MultiTermWeight { - weights: weights, - occur_filter: occur_filter, - } + MultiTermWeight::new(weights, occur_filter) } } impl Query for MultiTermQuery { - fn as_any(&self) -> &Any { self } - + fn weight(&self, searcher: &Searcher) -> Result> { Ok(box self.specialized_weight(searcher)) } @@ -66,16 +63,13 @@ impl Query for MultiTermQuery { impl From> for MultiTermQuery { fn from(occur_terms: Vec<(Occur, Term)>) -> MultiTermQuery { - MultiTermQuery { - occur_terms: occur_terms - } + MultiTermQuery { occur_terms: occur_terms } } } impl From> for MultiTermQuery { fn from(terms: Vec) -> MultiTermQuery { - let should_terms: Vec<(Occur, Term)> = terms - .into_iter() + let should_terms: Vec<(Occur, Term)> = terms.into_iter() .map(|term| (Occur::Should, term)) .collect(); MultiTermQuery::from(should_terms) diff --git a/src/query/multi_term_query/multi_term_weight.rs b/src/query/multi_term_query/multi_term_weight.rs index 6e12cd7a8..17e58d877 100644 --- a/src/query/multi_term_query/multi_term_weight.rs +++ b/src/query/multi_term_query/multi_term_weight.rs @@ -7,14 +7,28 @@ use postings::SegmentPostings; use query::term_query::{TermWeight, TermScorer}; use query::boolean_query::BooleanScorer; +/// Weight object associated to a [`MultiTermQuery`](./struct.MultiTermQuery.html). pub struct MultiTermWeight { - pub weights: Vec, - pub occur_filter: OccurFilter, + weights: Vec, + occur_filter: OccurFilter, } impl MultiTermWeight { + /// MultiTermWeigh constructor. + /// The `OccurFilter` is tied with the weights order. + pub fn new(weights: Vec, occur_filter: OccurFilter) -> MultiTermWeight { + MultiTermWeight { + weights: weights, + occur_filter: occur_filter, + } + } - pub fn specialized_scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result>>> { + /// Same as `scorer()`, except that rather than a boxed trait, + /// `specialized_scorer` returns a specific type of the scorer, allowing for + /// compile-time optimization. + pub fn specialized_scorer<'a>(&'a self, + reader: &'a SegmentReader) + -> Result>>> { let mut term_scorers: Vec> = Vec::new(); for term_weight in &self.weights { let term_scorer = try!(term_weight.specialized_scorer(reader)); @@ -22,12 +36,10 @@ impl MultiTermWeight { } Ok(BooleanScorer::new(term_scorers, self.occur_filter)) } - } impl Weight for MultiTermWeight { - fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { - Ok(box try!(self.specialized_scorer(reader))) + Ok(box try!(self.specialized_scorer(reader))) } -} \ No newline at end of file +} diff --git a/src/query/phrase_query/phrase_query.rs b/src/query/phrase_query/phrase_query.rs index 816fd4dbc..d44bdceb0 100644 --- a/src/query/phrase_query/phrase_query.rs +++ b/src/query/phrase_query/phrase_query.rs @@ -7,6 +7,19 @@ use query::Weight; use Result; +/// `PhraseQuery` matches a specific sequence of word. +/// For instance the phrase query for `"part time"` will match +/// the sentence +/// +/// **Alan just got a part time job.** +/// +/// On the other hand it will not match the sentence. +/// +/// **This is my favorite part of the job.** +/// +/// Using a `PhraseQuery` on a field requires positions +/// to be indexed for this field. +/// #[derive(Debug)] pub struct PhraseQuery { phrase_terms: Vec, @@ -24,7 +37,7 @@ impl Query for PhraseQuery { /// Create the weight associated to a query. /// /// See [Weight](./trait.Weight.html). - fn weight(&self, searcher: &Searcher) -> Result> { + fn weight(&self, _searcher: &Searcher) -> Result> { Ok(box PhraseWeight::from(self.phrase_terms.clone())) } diff --git a/src/query/scorer.rs b/src/query/scorer.rs index 5f21a2ab3..ae8c33d66 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -37,7 +37,9 @@ impl<'a> Scorer for Box { } } - +/// EmptyScorer is a dummy Scorer in which no document matches. +/// +/// It is useful for tests and handling edge cases. pub struct EmptyScorer; impl DocSet for EmptyScorer { From f7c882f3daf6368f06665e18fc832399845ae273 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Fri, 4 Nov 2016 12:23:27 +0900 Subject: [PATCH 18/19] issue/50 Added use case for BooleanQuery --- src/query/boolean_query/boolean_clause.rs | 18 ----- src/query/boolean_query/boolean_query.rs | 17 +++-- src/query/boolean_query/mod.rs | 83 ++++++++++++++++++++++- 3 files changed, 89 insertions(+), 29 deletions(-) delete mode 100644 src/query/boolean_query/boolean_clause.rs diff --git a/src/query/boolean_query/boolean_clause.rs b/src/query/boolean_query/boolean_clause.rs deleted file mode 100644 index e2e2a55b6..000000000 --- a/src/query/boolean_query/boolean_clause.rs +++ /dev/null @@ -1,18 +0,0 @@ -use query::Occur; -use query::Query; - -#[derive(Debug)] -pub struct BooleanClause { - pub query: Box, - pub occur: Occur, -} - - -impl BooleanClause { - pub fn new(query: Box, occur: Occur) -> BooleanClause { - BooleanClause { - query: query, - occur: occur, - } - } -} \ No newline at end of file diff --git a/src/query/boolean_query/boolean_query.rs b/src/query/boolean_query/boolean_query.rs index 195bbf20c..2d660f7aa 100644 --- a/src/query/boolean_query/boolean_query.rs +++ b/src/query/boolean_query/boolean_query.rs @@ -1,7 +1,6 @@ use Result; use std::any::Any; use super::boolean_weight::BooleanWeight; -use super::BooleanClause; use query::Weight; use Searcher; use query::Query; @@ -20,12 +19,12 @@ use query::OccurFilter; /// a `MustNot` occurence. #[derive(Debug)] pub struct BooleanQuery { - clauses: Vec, + subqueries: Vec<(Occur, Box)> } -impl From> for BooleanQuery { - fn from(clauses: Vec) -> BooleanQuery { - BooleanQuery { clauses: clauses } +impl From)>> for BooleanQuery { + fn from(subqueries: Vec<(Occur, Box)>) -> BooleanQuery { + BooleanQuery { subqueries: subqueries } } } @@ -35,13 +34,13 @@ impl Query for BooleanQuery { } fn weight(&self, searcher: &Searcher) -> Result> { - let sub_weights = try!(self.clauses + let sub_weights = try!(self.subqueries .iter() - .map(|clause| clause.query.weight(searcher)) + .map(|&(ref _occur, ref subquery)| subquery.weight(searcher)) .collect()); - let occurs: Vec = self.clauses + let occurs: Vec = self.subqueries .iter() - .map(|clause| clause.occur) + .map(|&(ref occur, ref _subquery)| *occur) .collect(); let filter = OccurFilter::new(&occurs); Ok(box BooleanWeight::new(sub_weights, filter)) diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 43d0e3336..36a113b7f 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -1,11 +1,9 @@ -mod boolean_clause; mod boolean_query; mod boolean_scorer; mod boolean_weight; mod score_combiner; pub use self::boolean_query::BooleanQuery; -pub use self::boolean_clause::BooleanClause; pub use self::boolean_scorer::BooleanScorer; pub use self::score_combiner::ScoreCombiner; @@ -20,12 +18,93 @@ mod tests { use query::OccurFilter; use query::term_query::TermScorer; use query::Occur; + use query::Query; + use query::TermQuery; + use collector::tests::TestCollector; + use Index; + use schema::*; use fastfield::{U32FastFieldReader}; fn abs_diff(left: f32, right: f32) -> f32 { (right - left).abs() } + + #[test] + pub fn test_boolean_query() { + let mut schema_builder = SchemaBuilder::default(); + let text_field = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + let index = Index::create_from_tempdir(schema).unwrap(); + { + // writing the segment + let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); + { + let doc = doc!(text_field => "a b c"); + index_writer.add_document(doc).unwrap(); + } + { + let doc = doc!(text_field => "a c"); + index_writer.add_document(doc).unwrap(); + } + { + let doc = doc!(text_field => "b c"); + index_writer.add_document(doc).unwrap(); + } + { + let doc = doc!(text_field => "a b c d"); + index_writer.add_document(doc).unwrap(); + } + { + let doc = doc!(text_field => "d"); + index_writer.add_document(doc).unwrap(); + } + assert!(index_writer.commit().is_ok()); + } + let make_term_query = |text: &str| { + let term_query = TermQuery::from(Term::from_field_text(text_field, text)); + let query: Box = box term_query; + query + }; + + + let matching_docs = |boolean_query: &Query| { + let searcher = index.searcher(); + let mut test_collector = TestCollector::default(); + boolean_query.search(&*searcher, &mut test_collector).unwrap(); + test_collector.docs() + }; + { + let boolean_query = BooleanQuery::from(vec![(Occur::Must, make_term_query("a")) ]); + assert_eq!(matching_docs(&boolean_query), vec!(0, 1, 3)); + } + { + let boolean_query = BooleanQuery::from(vec![(Occur::Should, make_term_query("a")) ]); + assert_eq!(matching_docs(&boolean_query), vec!(0, 1, 3)); + } + { + let boolean_query = BooleanQuery::from(vec![(Occur::Should, make_term_query("a")), (Occur::Should, make_term_query("b"))]); + assert_eq!(matching_docs(&boolean_query), vec!(0, 1, 2, 3)); + } + { + let boolean_query = BooleanQuery::from(vec![(Occur::Must, make_term_query("a")), (Occur::Should, make_term_query("b"))]); + assert_eq!(matching_docs(&boolean_query), vec!(0, 1, 3)); + } + { + let boolean_query = BooleanQuery::from(vec![(Occur::Must, make_term_query("a")), + (Occur::Should, make_term_query("b")), + (Occur::MustNot, make_term_query("d")), + ]); + assert_eq!(matching_docs(&boolean_query), vec!(0, 1)); + } + { + let boolean_query = BooleanQuery::from(vec![(Occur::MustNot, make_term_query("d")),]); + // TODO optimize this use case : only MustNot subqueries... no need + // to read any postings. + assert_eq!(matching_docs(&boolean_query), Vec::new()); + } + } + #[test] pub fn test_boolean_scorer() { let occurs = vec!(Occur::Should, Occur::Should); From 805dd85d3dcdd0c8d728ebcb58288db4a6bea5e7 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Fri, 4 Nov 2016 12:37:47 +0900 Subject: [PATCH 19/19] issue/50 NOBUG --- src/lib.rs | 66 ++++++++++++++--------------------- src/query/phrase_query/mod.rs | 16 ++++----- 2 files changed, 33 insertions(+), 49 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e22386b7c..4a33b0213 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,11 +59,16 @@ mod macros { macro_rules! doc( ($($field:ident => $value:expr),*) => {{ - let mut document = Document::default(); - $( - document.add(FieldValue::new($field, $value.into())); - )* - document + #[allow(unused_mut)] // avoid emitting a warning for `doc!()` + { + let mut document = Document::default(); + $( + document.add(FieldValue::new($field, $value.into())); + )* + document + } + + }}; ); } @@ -173,18 +178,15 @@ mod tests { // writing the segment let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); { - let mut doc = Document::default(); - doc.add_text(text_field, "af b"); + let doc = doc!(text_field=>"af b"); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c"); + let doc = doc!(text_field=>"a b c"); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c d"); + let doc = doc!(text_field=>"a b c d"); index_writer.add_document(doc).unwrap(); } assert!(index_writer.commit().is_ok()); @@ -199,27 +201,22 @@ mod tests { let index = Index::create_in_ram(schema_builder.build()); let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c"); - index_writer.add_document(doc).unwrap(); + index_writer.add_document(doc!(text_field=>"a b c")).unwrap(); index_writer.commit().unwrap(); } { { - let mut doc = Document::default(); - doc.add_text(text_field, "a"); + let doc = doc!(text_field=>"a"); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a a"); + let doc = doc!(text_field=>"a a"); index_writer.add_document(doc).unwrap(); } index_writer.commit().unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "c"); + let doc = doc!(text_field=>"c"); index_writer.add_document(doc).unwrap(); index_writer.commit().unwrap(); } @@ -245,17 +242,15 @@ mod tests { { let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c"); + let doc = doc!(text_field=>"a b c"); index_writer.add_document(doc).unwrap(); } { - let doc = Document::default(); + let doc = doc!(); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a b"); + let doc = doc!(text_field=>"a b"); index_writer.add_document(doc).unwrap(); } index_writer.commit().unwrap(); @@ -281,8 +276,7 @@ mod tests { // writing the segment let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); { - let mut doc = Document::default(); - doc.add_text(text_field, "af af af bc bc"); + let doc = doc!(text_field=>"af af af bc bc"); index_writer.add_document(doc).unwrap(); } index_writer.commit().unwrap(); @@ -310,18 +304,15 @@ mod tests { // writing the segment let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); { - let mut doc = Document::default(); - doc.add_text(text_field, "af af af b"); + let doc = doc!(text_field=>"af af af b"); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c"); + let doc = doc!(text_field=>"a b c"); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c d"); + let doc = doc!(text_field=>"a b c d"); index_writer.add_document(doc).unwrap(); } index_writer.commit().unwrap(); @@ -379,18 +370,15 @@ mod tests { // writing the segment let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); { - let mut doc = Document::default(); - doc.add_text(text_field, "af b"); + let doc = doc!(text_field=>"af b"); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c"); + let doc = doc!(text_field=>"a b c"); index_writer.add_document(doc).unwrap(); } { - let mut doc = Document::default(); - doc.add_text(text_field, "a b c d"); + let doc = doc!(text_field=>"a b c d"); index_writer.add_document(doc).unwrap(); } index_writer.commit().unwrap(); diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index d0bb61cfa..0500b8257 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -13,6 +13,7 @@ mod tests { use super::*; use query::Query; use core::Index; + use schema::FieldValue; use schema::{Document, Term, SchemaBuilder, TEXT}; use collector::tests::TestCollector; @@ -26,28 +27,23 @@ mod tests { { let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); { // 0 - let mut doc = Document::default(); - doc.add_text(text_field, "b b b d c g c"); + let doc = doc!(text_field=>"b b b d c g c"); index_writer.add_document(doc).unwrap(); } { // 1 - let mut doc = Document::default(); - doc.add_text(text_field, "a b b d c g c"); + let doc = doc!(text_field=>"a b b d c g c"); index_writer.add_document(doc).unwrap(); } { // 2 - let mut doc = Document::default(); - doc.add_text(text_field, "a b a b c"); + let doc = doc!(text_field=>"a b a b c"); index_writer.add_document(doc).unwrap(); } { // 3 - let mut doc = Document::default(); - doc.add_text(text_field, "c a b a d ga a"); + let doc = doc!(text_field=>"c a b a d ga a"); index_writer.add_document(doc).unwrap(); } { // 4 - let mut doc = Document::default(); - doc.add_text(text_field, "a b c"); + let doc = doc!(text_field=>"a b c"); index_writer.add_document(doc).unwrap(); } assert!(index_writer.commit().is_ok());