issue/50 Test broken. PhraseQuery uses Intersection DocSet

This commit is contained in:
Paul Masurel
2016-11-03 16:56:09 +09:00
parent a2c6ec93e0
commit 627e4f1f60
10 changed files with 108 additions and 126 deletions

View File

@@ -29,8 +29,14 @@ pub trait DocSet {
///
/// SkipResult expresses whether the `target value` was reached, overstepped,
/// or if the `DocSet` was entirely consumed without finding any value
/// greater or equal to the `target`.
/// greater or equal to the `target`.
///
/// WARNING: Calling skip always advances the docset.
/// More specifically, if the docset is already positionned on the target
/// skipping will advance to the next position and return SkipResult::Overstep.
///
fn skip_next(&mut self, target: DocId) -> SkipResult {
self.advance();
loop {
match self.doc().cmp(&target) {
Ordering::Less => {

View File

@@ -95,8 +95,6 @@ impl FreqHandler {
pub fn positions(&self, idx: usize) -> &[u32] {
let start = self.positions_offsets[idx];
let stop = self.positions_offsets[idx + 1];
println!("{} -> {}", start, stop);
println!("{} {:?}", idx, &self.positions_offsets[..10]);
&self.positions[start..stop]
}

View File

@@ -1,6 +1,5 @@
use postings::DocSet;
use postings::SkipResult;
use std::cmp::Ordering;
use DocId;
// TODO Find a way to specialize `IntersectionDocSet`
@@ -23,38 +22,53 @@ impl<TDocSet: DocSet> From<Vec<TDocSet>> for IntersectionDocSet<TDocSet> {
}
}
impl<TDocSet: DocSet> IntersectionDocSet<TDocSet> {
pub fn docsets(&self) -> &[TDocSet] {
&self.docsets[..]
}
}
impl<TDocSet: DocSet> DocSet for IntersectionDocSet<TDocSet> {
fn advance(&mut self,) -> bool {
if self.finished {
return false;
}
'outter: loop {
let doc_candidate = {
let mut first_docset = &mut self.docsets[0];
if !first_docset.advance() {
let num_docsets = self.docsets.len();
let mut count_matching = 1;
let mut doc_candidate = {
let mut first_docset = &mut self.docsets[0];
if !first_docset.advance() {
self.finished = true;
return false;
}
first_docset.doc()
};
let mut ord = 1;
loop {
let mut doc_set = &mut self.docsets[ord];
match doc_set.skip_next(doc_candidate) {
SkipResult::Reached => {
count_matching += 1;
if count_matching == num_docsets {
self.doc = doc_candidate;
return true;
}
}
SkipResult::End => {
self.finished = true;
return false;
}
first_docset.doc()
};
for docset_ord in 1..self.docsets.len() {
let docset: &mut TDocSet = &mut self.docsets[docset_ord];
match docset.skip_next(doc_candidate) {
SkipResult::End => {
self.finished = true;
return false;
}
SkipResult::OverStep => {
continue 'outter;
},
SkipResult::Reached => {}
SkipResult::OverStep => {
count_matching = 1;
doc_candidate = doc_set.doc();
}
}
self.doc = doc_candidate;
return true;
ord += 1;
if ord == num_docsets {
ord = 0;
}
}
}

View File

@@ -49,7 +49,7 @@ mod tests {
use core::Index;
use std::iter;
use datastruct::stacker::Heap;
use query::{Query, TermQuery};
use query::TermQuery;
#[test]

View File

@@ -1,9 +1,8 @@
#![allow(dead_code)]
use DocId;
use postings::{Postings, DocSet, SkipResult, HasLen};
use postings::{Postings, DocSet, HasLen};
use std::num::Wrapping;
use std::cmp::Ordering;
const EMPTY_ARRAY: [u32; 0] = [];
@@ -35,32 +34,6 @@ impl DocSet for VecPostings {
fn doc(&self,) -> DocId {
self.doc_ids[self.cursor.0]
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
let next_id: usize = (self.cursor + Wrapping(
if self.cursor.0 == usize::max_value() {
1
}
else {
0
}
)).0;
for i in next_id .. self.doc_ids.len() {
let doc: DocId = self.doc_ids[i];
match doc.cmp(&target) {
Ordering::Equal => {
self.cursor = Wrapping(i);
return SkipResult::Reached;
}
Ordering::Greater => {
self.cursor = Wrapping(i);
return SkipResult::OverStep;
}
Ordering::Less => {}
}
}
SkipResult::End
}
}
impl HasLen for VecPostings {
@@ -102,14 +75,6 @@ pub mod tests {
assert_eq!(postings.doc(), 300u32);
assert_eq!(postings.skip_next(6000u32), SkipResult::End);
}
#[test]
pub fn test_vec_postings_skip_without_advance() {
let doc_ids: Vec<DocId> = (0u32..1024u32).map(|e| e*3).collect();
let mut postings = VecPostings::from(doc_ids);
assert_eq!(postings.skip_next(300u32), SkipResult::Reached);
assert_eq!(postings.doc(), 300u32);
assert_eq!(postings.skip_next(6000u32), SkipResult::End);
}
}

View File

@@ -77,11 +77,6 @@ impl<TScorer: Scorer> BooleanScorer<TScorer> {
}
}
pub fn num_subscorers(&self) -> usize {
self.postings.len()
}
/// Advances the head of our heap (the segment postings with the lowest doc)
/// It will also update the new current `DocId` as well as the term frequency
/// associated with the segment postings.

View File

@@ -9,14 +9,13 @@ pub use self::phrase_scorer::PhraseScorer;
#[cfg(test)]
mod tests {
use super::*;
use query::Query;
use core::Index;
use schema::{Document, Term, SchemaBuilder, TEXT};
use collector::tests::TestCollector;
#[test]
pub fn test_phrase_query() {
@@ -26,43 +25,45 @@ mod tests {
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap();
{
{ // 0
let mut doc = Document::default();
doc.add_text(text_field, "b b b d c g c");
index_writer.add_document(doc).unwrap();
}
{ // 1
let mut doc = Document::default();
doc.add_text(text_field, "a b b d c g c");
index_writer.add_document(doc).unwrap();
}
// {
// let mut doc = Document::default();
// doc.add_text(text_field, "a b a b c");
// index_writer.add_document(doc).unwrap();
// }
// {
// let mut doc = Document::default();
// doc.add_text(text_field, "c a b a d ga a");
// index_writer.add_document(doc).unwrap();
// }
// {
// let mut doc = Document::default();
// doc.add_text(text_field, "a b c");
// index_writer.add_document(doc).unwrap();
// }
{ // 2
let mut doc = Document::default();
doc.add_text(text_field, "a b a b c");
index_writer.add_document(doc).unwrap();
}
{ // 3
let mut doc = Document::default();
doc.add_text(text_field, "c a b a d ga a");
index_writer.add_document(doc).unwrap();
}
{ // 4
let mut doc = Document::default();
doc.add_text(text_field, "a b c");
index_writer.add_document(doc).unwrap();
}
assert!(index_writer.commit().is_ok());
}
let mut test_collector = TestCollector::default();
let build_query = |texts: Vec<&str>| {
let terms: Vec<Term> = texts
.iter()
.map(|text| {
Term::from_field_text(text_field, text)
})
.map(|text| Term::from_field_text(text_field, text))
.collect();
PhraseQuery::from(terms)
};
let phrase_query = build_query(vec!("a", "b"));
let phrase_query = build_query(vec!("a", "b", "c"));
let searcher = index.searcher();
phrase_query.search(&*searcher, &mut test_collector).expect("search should succeed");
assert_eq!(test_collector.docs(), vec!(0, 1, 2, 3));
assert_eq!(test_collector.docs(), vec!(1, 2, 4));
}
}

View File

@@ -1,9 +1,7 @@
use schema::Term;
use query::Query;
use core::searcher::Searcher;
use query::Occur;
use super::PhraseWeight;
use query::MultiTermQuery;
use std::any::Any;
use query::Weight;
use Result;
@@ -11,7 +9,7 @@ use Result;
#[derive(Debug)]
pub struct PhraseQuery {
all_terms_query: MultiTermQuery,
phrase_terms: Vec<Term>,
}
impl Query for PhraseQuery {
@@ -27,21 +25,17 @@ impl Query for PhraseQuery {
///
/// See [Weight](./trait.Weight.html).
fn weight(&self, searcher: &Searcher) -> Result<Box<Weight>> {
let multi_term_weight = self.all_terms_query.specialized_weight(searcher);
Ok(box PhraseWeight::from(multi_term_weight))
Ok(box PhraseWeight::from(self.phrase_terms.clone()))
}
}
impl From<Vec<Term>> for PhraseQuery {
fn from(terms: Vec<Term>) -> PhraseQuery {
assert!(terms.len() > 1);
let occur_terms: Vec<(Occur, Term)> = terms.into_iter()
.map(|term| (Occur::Must, term))
.collect();
fn from(phrase_terms: Vec<Term>) -> PhraseQuery {
assert!(phrase_terms.len() > 1);
PhraseQuery {
all_terms_query: MultiTermQuery::from(occur_terms),
phrase_terms: phrase_terms,
}
}
}

View File

@@ -1,34 +1,34 @@
use query::Scorer;
use DocSet;
use query::term_query::TermScorer;
use query::boolean_query::BooleanScorer;
use postings::SegmentPostings;
use postings::Postings;
use postings::IntersectionDocSet;
use DocId;
pub struct PhraseScorer<'a> {
pub all_term_scorer: BooleanScorer<TermScorer<SegmentPostings<'a>>>,
pub intersection_docset: IntersectionDocSet<SegmentPostings<'a>>,
pub positions_offsets: Vec<u32>,
}
impl<'a> PhraseScorer<'a> {
fn phrase_match(&self) -> bool {
println!("phrase_match");
let mut positions_arr: Vec<&[u32]> = self.all_term_scorer
.scorers()
let mut positions_arr: Vec<&[u32]> = self.intersection_docset
.docsets()
.iter()
.map(|scorer| {
println!("{:?}", scorer.doc());
scorer.postings().positions()
.map(|posting| {
posting.positions()
})
.collect();
println!("positions arr {:?}", positions_arr);
let mut cur = 0;
'outer: loop {
for i in 0..positions_arr.len() {
println!("i {}", i);
let positions: &mut &[u32] = &mut positions_arr[i];
println!("{} {:?} {:?}", i, positions, self.positions_offsets);
if positions.len() == 0 {
println!("NOPE");
return false;
}
let head_position = positions[0] + self.positions_offsets[i];
@@ -51,9 +51,10 @@ impl<'a> PhraseScorer<'a> {
impl<'a> DocSet for PhraseScorer<'a> {
fn advance(&mut self,) -> bool {
println!("docset advance");
while self.all_term_scorer.advance() {
while self.intersection_docset.advance() {
println!("doc {}", self.intersection_docset.doc());
if self.phrase_match() {
println!("return {}", self.intersection_docset.doc());
return true;
}
}
@@ -61,14 +62,14 @@ impl<'a> DocSet for PhraseScorer<'a> {
}
fn doc(&self,) -> DocId {
self.all_term_scorer.doc()
self.intersection_docset.doc()
}
}
impl<'a> Scorer for PhraseScorer<'a> {
fn score(&self,) -> f32 {
self.all_term_scorer.score()
1f32
}
}

View File

@@ -1,29 +1,37 @@
use query::Weight;
use query::Scorer;
use schema::Term;
use postings::SegmentPostingsOption;
use core::SegmentReader;
use super::PhraseScorer;
use query::MultiTermWeight;
use postings::IntersectionDocSet;
use Result;
pub struct PhraseWeight {
all_term_weight: MultiTermWeight,
phrase_terms: Vec<Term>,
}
impl From<MultiTermWeight> for PhraseWeight {
fn from(all_term_weight: MultiTermWeight) -> PhraseWeight {
impl From<Vec<Term>> for PhraseWeight {
fn from(phrase_terms: Vec<Term>) -> PhraseWeight {
PhraseWeight {
all_term_weight: all_term_weight
phrase_terms: phrase_terms
}
}
}
impl Weight for PhraseWeight {
fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result<Box<Scorer + 'a>> {
let all_term_scorer = try!(self.all_term_weight.specialized_scorer(reader));
let positions_offsets: Vec<u32> = (0u32..all_term_scorer.num_subscorers() as u32).collect();
let mut term_postings_list = Vec::new();
for term in &self.phrase_terms {
let term_postings_option = reader.read_postings(term, SegmentPostingsOption::FreqAndPositions);
if let Some(term_postings) = term_postings_option {
term_postings_list.push(term_postings);
}
}
let positions_offsets: Vec<u32> = (0u32..self.phrase_terms.len() as u32).collect();
Ok(box PhraseScorer {
all_term_scorer: all_term_scorer,
positions_offsets: positions_offsets
intersection_docset: IntersectionDocSet::from(term_postings_list),
positions_offsets: positions_offsets,
})
}
}