diff --git a/src/postings/docset.rs b/src/postings/docset.rs index faf839325..9c84c2dd8 100644 --- a/src/postings/docset.rs +++ b/src/postings/docset.rs @@ -29,8 +29,14 @@ pub trait DocSet { /// /// SkipResult expresses whether the `target value` was reached, overstepped, /// or if the `DocSet` was entirely consumed without finding any value - /// greater or equal to the `target`. + /// greater or equal to the `target`. + /// + /// WARNING: Calling skip always advances the docset. + /// More specifically, if the docset is already positioned on the target + /// skipping will advance to the next position and return `SkipResult::OverStep`. + /// fn skip_next(&mut self, target: DocId) -> SkipResult { + self.advance(); loop { match self.doc().cmp(&target) { Ordering::Less => { diff --git a/src/postings/freq_handler.rs b/src/postings/freq_handler.rs index 660ac4977..cb4d6d92e 100644 --- a/src/postings/freq_handler.rs +++ b/src/postings/freq_handler.rs @@ -95,8 +95,6 @@ impl FreqHandler { pub fn positions(&self, idx: usize) -> &[u32] { let start = self.positions_offsets[idx]; let stop = self.positions_offsets[idx + 1]; - println!("{} -> {}", start, stop); - println!("{} {:?}", idx, &self.positions_offsets[..10]); &self.positions[start..stop] } diff --git a/src/postings/intersection.rs b/src/postings/intersection.rs index 2f49910ed..6946a652e 100644 --- a/src/postings/intersection.rs +++ b/src/postings/intersection.rs @@ -1,6 +1,5 @@ use postings::DocSet; use postings::SkipResult; -use std::cmp::Ordering; use DocId; // TODO Find a way to specialize `IntersectionDocSet` @@ -23,38 +22,53 @@ impl From> for IntersectionDocSet { } } +impl IntersectionDocSet { + pub fn docsets(&self) -> &[TDocSet] { + &self.docsets[..] 
+ } +} + impl DocSet for IntersectionDocSet { - + fn advance(&mut self,) -> bool { if self.finished { return false; } - - 'outter: loop { - let doc_candidate = { - let mut first_docset = &mut self.docsets[0]; - if !first_docset.advance() { + let num_docsets = self.docsets.len(); + let mut count_matching = 1; + let mut doc_candidate = { + let mut first_docset = &mut self.docsets[0]; + if !first_docset.advance() { + self.finished = true; + return false; + } + first_docset.doc() + }; + let mut ord = 1; + loop { + let mut doc_set = &mut self.docsets[ord]; + match doc_set.skip_next(doc_candidate) { + SkipResult::Reached => { + count_matching += 1; + if count_matching == num_docsets { + self.doc = doc_candidate; + return true; + } + } + SkipResult::End => { self.finished = true; return false; } - first_docset.doc() - }; - for docset_ord in 1..self.docsets.len() { - let docset: &mut TDocSet = &mut self.docsets[docset_ord]; - match docset.skip_next(doc_candidate) { - SkipResult::End => { - self.finished = true; - return false; - } - SkipResult::OverStep => { - continue 'outter; - }, - SkipResult::Reached => {} + SkipResult::OverStep => { + count_matching = 1; + doc_candidate = doc_set.doc(); } } - self.doc = doc_candidate; - return true; + ord += 1; + if ord == num_docsets { + ord = 0; + } } } diff --git a/src/postings/mod.rs b/src/postings/mod.rs index 7992f629e..a0aa35741 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -49,7 +49,7 @@ mod tests { use core::Index; use std::iter; use datastruct::stacker::Heap; - use query::{Query, TermQuery}; + use query::TermQuery; #[test] diff --git a/src/postings/vec_postings.rs b/src/postings/vec_postings.rs index 02bcf699d..8704915e0 100644 --- a/src/postings/vec_postings.rs +++ b/src/postings/vec_postings.rs @@ -1,9 +1,8 @@ #![allow(dead_code)] use DocId; -use postings::{Postings, DocSet, SkipResult, HasLen}; +use postings::{Postings, DocSet, HasLen}; use std::num::Wrapping; -use std::cmp::Ordering; const EMPTY_ARRAY: 
[u32; 0] = []; @@ -35,32 +34,6 @@ impl DocSet for VecPostings { fn doc(&self,) -> DocId { self.doc_ids[self.cursor.0] } - - fn skip_next(&mut self, target: DocId) -> SkipResult { - let next_id: usize = (self.cursor + Wrapping( - if self.cursor.0 == usize::max_value() { - 1 - } - else { - 0 - } - )).0; - for i in next_id .. self.doc_ids.len() { - let doc: DocId = self.doc_ids[i]; - match doc.cmp(&target) { - Ordering::Equal => { - self.cursor = Wrapping(i); - return SkipResult::Reached; - } - Ordering::Greater => { - self.cursor = Wrapping(i); - return SkipResult::OverStep; - } - Ordering::Less => {} - } - } - SkipResult::End - } } impl HasLen for VecPostings { @@ -102,14 +75,6 @@ pub mod tests { assert_eq!(postings.doc(), 300u32); assert_eq!(postings.skip_next(6000u32), SkipResult::End); } - - #[test] - pub fn test_vec_postings_skip_without_advance() { - let doc_ids: Vec = (0u32..1024u32).map(|e| e*3).collect(); - let mut postings = VecPostings::from(doc_ids); - assert_eq!(postings.skip_next(300u32), SkipResult::Reached); - assert_eq!(postings.doc(), 300u32); - assert_eq!(postings.skip_next(6000u32), SkipResult::End); - } + } diff --git a/src/query/boolean_query/boolean_scorer.rs b/src/query/boolean_query/boolean_scorer.rs index 4be2cd7c9..e6c17af69 100644 --- a/src/query/boolean_query/boolean_scorer.rs +++ b/src/query/boolean_query/boolean_scorer.rs @@ -77,11 +77,6 @@ impl BooleanScorer { } } - pub fn num_subscorers(&self) -> usize { - self.postings.len() - } - - /// Advances the head of our heap (the segment postings with the lowest doc) /// It will also update the new current `DocId` as well as the term frequency /// associated with the segment postings. 
diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index 12fe1f6b5..855d198c2 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -9,14 +9,13 @@ pub use self::phrase_scorer::PhraseScorer; #[cfg(test)] mod tests { - + use super::*; use query::Query; use core::Index; use schema::{Document, Term, SchemaBuilder, TEXT}; use collector::tests::TestCollector; - #[test] pub fn test_phrase_query() { @@ -26,43 +25,45 @@ mod tests { let index = Index::create_in_ram(schema); { let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap(); - { + { // 0 + let mut doc = Document::default(); + doc.add_text(text_field, "b b b d c g c"); + index_writer.add_document(doc).unwrap(); + } + { // 1 let mut doc = Document::default(); doc.add_text(text_field, "a b b d c g c"); index_writer.add_document(doc).unwrap(); } - // { - // let mut doc = Document::default(); - // doc.add_text(text_field, "a b a b c"); - // index_writer.add_document(doc).unwrap(); - // } - // { - // let mut doc = Document::default(); - // doc.add_text(text_field, "c a b a d ga a"); - // index_writer.add_document(doc).unwrap(); - // } - // { - // let mut doc = Document::default(); - // doc.add_text(text_field, "a b c"); - // index_writer.add_document(doc).unwrap(); - // } + { // 2 + let mut doc = Document::default(); + doc.add_text(text_field, "a b a b c"); + index_writer.add_document(doc).unwrap(); + } + { // 3 + let mut doc = Document::default(); + doc.add_text(text_field, "c a b a d ga a"); + index_writer.add_document(doc).unwrap(); + } + { // 4 + let mut doc = Document::default(); + doc.add_text(text_field, "a b c"); + index_writer.add_document(doc).unwrap(); + } assert!(index_writer.commit().is_ok()); } let mut test_collector = TestCollector::default(); let build_query = |texts: Vec<&str>| { let terms: Vec = texts .iter() - .map(|text| { - Term::from_field_text(text_field, text) - }) + .map(|text| Term::from_field_text(text_field, text)) 
.collect(); PhraseQuery::from(terms) }; - let phrase_query = build_query(vec!("a", "b")); + let phrase_query = build_query(vec!("a", "b", "c")); let searcher = index.searcher(); phrase_query.search(&*searcher, &mut test_collector).expect("search should succeed"); - assert_eq!(test_collector.docs(), vec!(0, 1, 2, 3)); + assert_eq!(test_collector.docs(), vec!(1, 2, 4)); } - } diff --git a/src/query/phrase_query/phrase_query.rs b/src/query/phrase_query/phrase_query.rs index 9058f96b8..816fd4dbc 100644 --- a/src/query/phrase_query/phrase_query.rs +++ b/src/query/phrase_query/phrase_query.rs @@ -1,9 +1,7 @@ use schema::Term; use query::Query; use core::searcher::Searcher; -use query::Occur; use super::PhraseWeight; -use query::MultiTermQuery; use std::any::Any; use query::Weight; use Result; @@ -11,7 +9,7 @@ use Result; #[derive(Debug)] pub struct PhraseQuery { - all_terms_query: MultiTermQuery, + phrase_terms: Vec, } impl Query for PhraseQuery { @@ -27,21 +25,17 @@ impl Query for PhraseQuery { /// /// See [Weight](./trait.Weight.html). 
fn weight(&self, searcher: &Searcher) -> Result> { - let multi_term_weight = self.all_terms_query.specialized_weight(searcher); - Ok(box PhraseWeight::from(multi_term_weight)) + Ok(box PhraseWeight::from(self.phrase_terms.clone())) } } impl From> for PhraseQuery { - fn from(terms: Vec) -> PhraseQuery { - assert!(terms.len() > 1); - let occur_terms: Vec<(Occur, Term)> = terms.into_iter() - .map(|term| (Occur::Must, term)) - .collect(); + fn from(phrase_terms: Vec) -> PhraseQuery { + assert!(phrase_terms.len() > 1); PhraseQuery { - all_terms_query: MultiTermQuery::from(occur_terms), + phrase_terms: phrase_terms, } } } diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index b80422b94..725099f12 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -1,34 +1,34 @@ use query::Scorer; use DocSet; -use query::term_query::TermScorer; -use query::boolean_query::BooleanScorer; use postings::SegmentPostings; use postings::Postings; +use postings::IntersectionDocSet; use DocId; pub struct PhraseScorer<'a> { - pub all_term_scorer: BooleanScorer>>, + pub intersection_docset: IntersectionDocSet>, pub positions_offsets: Vec, } impl<'a> PhraseScorer<'a> { fn phrase_match(&self) -> bool { - println!("phrase_match"); - let mut positions_arr: Vec<&[u32]> = self.all_term_scorer - .scorers() + let mut positions_arr: Vec<&[u32]> = self.intersection_docset + .docsets() .iter() - .map(|scorer| { - println!("{:?}", scorer.doc()); - scorer.postings().positions() + .map(|posting| { + posting.positions() }) .collect(); println!("positions arr {:?}", positions_arr); + + let mut cur = 0; 'outer: loop { for i in 0..positions_arr.len() { + println!("i {}", i); let positions: &mut &[u32] = &mut positions_arr[i]; - println!("{} {:?} {:?}", i, positions, self.positions_offsets); if positions.len() == 0 { + println!("NOPE"); return false; } let head_position = positions[0] + self.positions_offsets[i]; @@ -51,9 
+51,10 @@ impl<'a> PhraseScorer<'a> { impl<'a> DocSet for PhraseScorer<'a> { fn advance(&mut self,) -> bool { - println!("docset advance"); - while self.all_term_scorer.advance() { + while self.intersection_docset.advance() { + println!("doc {}", self.intersection_docset.doc()); if self.phrase_match() { + println!("return {}", self.intersection_docset.doc()); return true; } } @@ -61,14 +62,14 @@ impl<'a> DocSet for PhraseScorer<'a> { } fn doc(&self,) -> DocId { - self.all_term_scorer.doc() + self.intersection_docset.doc() } } impl<'a> Scorer for PhraseScorer<'a> { fn score(&self,) -> f32 { - self.all_term_scorer.score() + 1f32 } } diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs index b3d1e002c..234e22073 100644 --- a/src/query/phrase_query/phrase_weight.rs +++ b/src/query/phrase_query/phrase_weight.rs @@ -1,29 +1,37 @@ use query::Weight; use query::Scorer; +use schema::Term; +use postings::SegmentPostingsOption; use core::SegmentReader; use super::PhraseScorer; -use query::MultiTermWeight; +use postings::IntersectionDocSet; use Result; pub struct PhraseWeight { - all_term_weight: MultiTermWeight, + phrase_terms: Vec, } -impl From for PhraseWeight { - fn from(all_term_weight: MultiTermWeight) -> PhraseWeight { +impl From> for PhraseWeight { + fn from(phrase_terms: Vec) -> PhraseWeight { PhraseWeight { - all_term_weight: all_term_weight + phrase_terms: phrase_terms } } } impl Weight for PhraseWeight { fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result> { - let all_term_scorer = try!(self.all_term_weight.specialized_scorer(reader)); - let positions_offsets: Vec = (0u32..all_term_scorer.num_subscorers() as u32).collect(); + let mut term_postings_list = Vec::new(); + for term in &self.phrase_terms { + let term_postings_option = reader.read_postings(term, SegmentPostingsOption::FreqAndPositions); + if let Some(term_postings) = term_postings_option { + term_postings_list.push(term_postings); + } + } + let 
positions_offsets: Vec = (0u32..self.phrase_terms.len() as u32).collect(); Ok(box PhraseScorer { - all_term_scorer: all_term_scorer, - positions_offsets: positions_offsets + intersection_docset: IntersectionDocSet::from(term_postings_list), + positions_offsets: positions_offsets, }) } }