issue/50: fixed `VecPostings::skip_next` and reworked `IntersectionDocSet` to be generic over `DocSet`.

This commit is contained in:
Paul Masurel
2016-11-03 14:28:14 +09:00
parent 59d1b9e2bb
commit a2c6ec93e0
14 changed files with 291 additions and 201 deletions

View File

@@ -95,7 +95,9 @@ impl FreqHandler {
pub fn positions(&self, idx: usize) -> &[u32] {
let start = self.positions_offsets[idx];
let stop = self.positions_offsets[idx + 1];
&self.positions[start..stop]
println!("{} -> {}", start, stop);
println!("{} {:?}", idx, &self.positions_offsets[..10]);
&self.positions[start..stop]
}
/// Decompresses a complete frequency block

View File

@@ -1,90 +1,64 @@
use postings::DocSet;
use postings::SkipResult;
use std::cmp::Ordering;
use DocId;
// TODO Find a way to specialize `IntersectionDocSet`
/// Creates a `DocSet` that iterator through the intersection of two `DocSet`s.
pub struct IntersectionDocSet<'a> {
left: Box<DocSet + 'a>,
right: Box<DocSet + 'a>,
pub struct IntersectionDocSet<TDocSet: DocSet> {
docsets: Vec<TDocSet>,
finished: bool,
doc: DocId,
}
impl<'a> IntersectionDocSet<'a> {
/// Intersect two `DocSet`s
fn from_pair(left: Box<DocSet + 'a>, right: Box<DocSet + 'a>) -> IntersectionDocSet<'a> {
impl<TDocSet: DocSet> From<Vec<TDocSet>> for IntersectionDocSet<TDocSet> {
fn from(docsets: Vec<TDocSet>) -> IntersectionDocSet<TDocSet> {
assert!(docsets.len() >= 2);
IntersectionDocSet {
left: left,
right: right,
docsets: docsets,
finished: false,
}
}
/// Intersect a list of `DocSet`s
pub fn new(mut postings: Vec<Box<DocSet + 'a>>) -> IntersectionDocSet<'a> {
let left = postings.pop().unwrap();
let right =
if postings.len() == 1 {
postings.pop().unwrap()
}
else {
Box::new(IntersectionDocSet::new(postings))
};
IntersectionDocSet::from_pair(left, right)
doc: DocId::max_value(),
}
}
}
impl<'a> DocSet for IntersectionDocSet<'a> {
impl<TDocSet: DocSet> DocSet for IntersectionDocSet<TDocSet> {
fn advance(&mut self,) -> bool {
if self.finished {
return false;
}
if !self.left.advance() {
self.finished = true;
return false;
}
if !self.right.advance() {
self.finished = true;
return false;
}
loop {
match self.left.doc().cmp(&self.right.doc()) {
Ordering::Equal => {
return true;
'outter: loop {
let doc_candidate = {
let mut first_docset = &mut self.docsets[0];
if !first_docset.advance() {
self.finished = true;
return false;
}
Ordering::Less => {
if !self.left.advance() {
self.finished = true;
return false;
}
}
Ordering::Greater => {
if !self.right.advance() {
first_docset.doc()
};
for docset_ord in 1..self.docsets.len() {
let docset: &mut TDocSet = &mut self.docsets[docset_ord];
match docset.skip_next(doc_candidate) {
SkipResult::End => {
self.finished = true;
return false;
}
SkipResult::OverStep => {
continue 'outter;
},
SkipResult::Reached => {}
}
}
self.doc = doc_candidate;
return true;
}
}
fn doc(&self,) -> DocId {
self.left.doc()
self.doc
}
}
/// Intersects a `Vec` of `DocSets`
// NOTE(review): this free function belongs to the boxed (`Box<DocSet>`)
// version of the intersection API — it boxes every posting list and
// delegates to `IntersectionDocSet::new`. It appears superseded by the
// generic `From<Vec<TDocSet>>` implementation; confirm whether anything
// still re-exports it before keeping/removing.
pub fn intersection<'a, TDocSet: DocSet + 'a>(postings: Vec<TDocSet>) -> IntersectionDocSet<'a> {
// Erase the concrete posting type so heterogeneous docsets can be mixed.
let boxed_postings: Vec<Box<DocSet + 'a>> = postings
.into_iter()
.map(|postings: TDocSet| {
Box::new(postings) as Box<DocSet + 'a>
})
.collect();
IntersectionDocSet::new(boxed_postings)
}

View File

@@ -32,7 +32,6 @@ pub use self::vec_postings::VecPostings;
pub use self::chained_postings::ChainedPostings;
pub use self::segment_postings::SegmentPostings;
pub use self::intersection::intersection;
pub use self::intersection::IntersectionDocSet;
pub use self::freq_handler::FreqHandler;
@@ -50,6 +49,7 @@ mod tests {
use core::Index;
use std::iter;
use datastruct::stacker::Heap;
use query::{Query, TermQuery};
#[test]
@@ -73,7 +73,7 @@ mod tests {
}
#[test]
pub fn test_position_and_fieldnorm_write_fullstack() {
pub fn test_position_and_fieldnorm() {
let mut schema_builder = SchemaBuilder::default();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
@@ -154,12 +154,43 @@ mod tests {
}
}
#[test]
// Checks that term positions are decoded correctly for a term that is
// absent from one document: doc 0 contains no "a", doc 1 contains "a"
// at token positions 1 and 4.
pub fn test_position_and_fieldnorm2() {
let mut schema_builder = SchemaBuilder::default();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap();
{
let mut doc = Document::default();
// doc 0: no occurrence of "a".
doc.add_text(text_field, "g b b d c g c");
index_writer.add_document(doc).unwrap();
}
{
let mut doc = Document::default();
// doc 1: "a" occurs at token positions 1 and 4.
doc.add_text(text_field, "g a b b a d c g c");
index_writer.add_document(doc).unwrap();
}
assert!(index_writer.commit().is_ok());
}
let term_query = TermQuery::from(Term::from_field_text(text_field, "a"));
let searcher = index.searcher();
let mut term_weight = term_query.specialized_weight(&*searcher);
// Positions are only decoded when explicitly requested.
term_weight.segment_postings_options = SegmentPostingsOption::FreqAndPositions;
let segment_reader = &searcher.segment_readers()[0];
let mut term_scorer = term_weight.specialized_scorer(segment_reader).unwrap();
assert!(term_scorer.advance());
// Only doc 1 contains "a"; its positions must match the tokenization above.
assert_eq!(term_scorer.doc(), 1u32);
assert_eq!(term_scorer.postings().positions(), &[1u32, 4]);
}
#[test]
fn test_intersection() {
{
let left = Box::new(VecPostings::from(vec!(1, 3, 9)));
let right = Box::new(VecPostings::from(vec!(3, 4, 9, 18)));
let mut intersection = IntersectionDocSet::new(vec!(left, right));
let left = VecPostings::from(vec!(1, 3, 9));
let right = VecPostings::from(vec!(3, 4, 9, 18));
let mut intersection = IntersectionDocSet::from(vec!(left, right));
assert!(intersection.advance());
assert_eq!(intersection.doc(), 3);
assert!(intersection.advance());
@@ -167,10 +198,10 @@ mod tests {
assert!(!intersection.advance());
}
{
let a = Box::new(VecPostings::from(vec!(1, 3, 9)));
let b = Box::new(VecPostings::from(vec!(3, 4, 9, 18)));
let c = Box::new(VecPostings::from(vec!(1, 5, 9, 111)));
let mut intersection = IntersectionDocSet::new(vec!(a, b, c));
let a = VecPostings::from(vec!(1, 3, 9));
let b = VecPostings::from(vec!(3, 4, 9, 18));
let c = VecPostings::from(vec!(1, 5, 9, 111));
let mut intersection = IntersectionDocSet::from(vec!(a, b, c));
assert!(intersection.advance());
assert_eq!(intersection.doc(), 9);
assert!(!intersection.advance());

View File

@@ -5,7 +5,8 @@
/// Since decoding information is not free, this makes it possible to
/// avoid this extra cost when the information is not required.
/// For instance, positions are useful when running phrase queries
/// but useless in other queries,
/// but useless in other queries.
#[derive(Clone, Copy)]
pub enum SegmentPostingsOption {
/// Only the doc ids are decoded
NoFreq,

View File

@@ -37,59 +37,29 @@ impl DocSet for VecPostings {
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
let mut start: usize = self.cursor.0;
match self.doc_ids[start].cmp(&target) {
Ordering::Equal => {
return SkipResult::Reached;
}
Ordering::Greater => {
if self.cursor.0 < self.doc_ids.len() {
return SkipResult::OverStep;
let next_id: usize = (self.cursor + Wrapping(
if self.cursor.0 == usize::max_value() {
1
}
else {
return SkipResult::End;
0
}
}
Ordering::Less => {
// see below
}
}
let mut end = self.doc_ids.len();
while end - start > 1 {
// find an upper bound
let mut jump = 1;
loop {
let jump_dest = start + jump;
if jump_dest >= end {
// we jump out of bounds
break;
)).0;
for i in next_id .. self.doc_ids.len() {
let doc: DocId = self.doc_ids[i];
match doc.cmp(&target) {
Ordering::Equal => {
self.cursor = Wrapping(i);
return SkipResult::Reached;
}
match self.doc_ids[jump_dest].cmp(&target) {
Ordering::Less => {
// still below the target, let's keep jumping.
start = jump_dest;
jump *= 2;
}
Ordering::Equal => {
self.cursor = Wrapping(jump_dest);
return SkipResult::Reached;
}
Ordering::Greater => {
end = jump_dest;
break;
}
}
Ordering::Greater => {
self.cursor = Wrapping(i);
return SkipResult::OverStep;
}
Ordering::Less => {}
}
}
self.cursor = Wrapping(start + 1);
if self.cursor.0 < self.doc_ids.len() {
SkipResult::OverStep
}
else {
SkipResult::End
}
SkipResult::End
}
}
@@ -132,5 +102,14 @@ pub mod tests {
assert_eq!(postings.doc(), 300u32);
assert_eq!(postings.skip_next(6000u32), SkipResult::End);
}
#[test]
// `skip_next` must be usable on a fresh postings list, before `advance`
// has ever been called.
pub fn test_vec_postings_skip_without_advance() {
    let docs: Vec<DocId> = (0u32..1024u32).map(|i| i * 3).collect();
    let mut postings = VecPostings::from(docs);
    // 300 is a multiple of 3, so it is present.
    assert_eq!(postings.skip_next(300u32), SkipResult::Reached);
    assert_eq!(postings.doc(), 300u32);
    // 6000 is beyond the last doc (3069).
    assert_eq!(postings.skip_next(6000u32), SkipResult::End);
}
}

View File

@@ -77,6 +77,10 @@ impl<TScorer: Scorer> BooleanScorer<TScorer> {
}
}
/// Returns the number of underlying scorers combined by this boolean scorer.
pub fn num_subscorers(&self) -> usize {
self.postings.len()
}
/// Advances the head of our heap (the segment postings with the lowest doc)
/// It will also update the new current `DocId` as well as the term frequency
@@ -148,72 +152,3 @@ impl<TScorer: Scorer> Scorer for BooleanScorer<TScorer> {
}
}
#[cfg(test)]
mod tests {
use super::*;
use postings::{DocSet, VecPostings};
use query::Scorer;
use query::OccurFilter;
use query::term_query::TermScorer;
use query::Occur;
use fastfield::{U32FastFieldReader};
// Absolute difference, used for approximate float comparisons below.
fn abs_diff(left: f32, right: f32) -> f32 {
(right - left).abs()
}
#[test]
// Disjunction (two `Should` clauses) of two term scorers: the matched
// docs must be the union of {1, 2, 3} and {1, 3, 8}, in order.
pub fn test_boolean_scorer() {
let occurs = vec!(Occur::Should, Occur::Should);
let occur_filter = OccurFilter::new(&occurs);
let left_fieldnorms = U32FastFieldReader::from(vec!(100,200,300));
let left = VecPostings::from(vec!(1, 2, 3));
let left_scorer = TermScorer {
idf: 1f32,
fieldnorm_reader: left_fieldnorms,
postings: left,
};
let right_fieldnorms = U32FastFieldReader::from(vec!(15,25,35));
let right = VecPostings::from(vec!(1, 3, 8));
let right_scorer = TermScorer {
idf: 4f32,
fieldnorm_reader: right_fieldnorms,
postings: right,
};
let mut boolean_scorer = BooleanScorer::new(vec!(left_scorer, right_scorer), occur_filter);
// Expected scores are the values the scorer currently produces;
// NOTE(review): they depend on `TermScorer`'s scoring formula — confirm
// against it if they ever drift.
assert_eq!(boolean_scorer.next(), Some(1u32));
assert!(abs_diff(boolean_scorer.score(), 0.8707107) < 0.001);
assert_eq!(boolean_scorer.next(), Some(2u32));
assert!(abs_diff(boolean_scorer.score(), 0.028867513) < 0.001f32);
assert_eq!(boolean_scorer.next(), Some(3u32));
assert_eq!(boolean_scorer.next(), Some(8u32));
assert!(abs_diff(boolean_scorer.score(), 0.5163978) < 0.001f32);
assert!(!boolean_scorer.advance());
}
#[test]
// Single term scorer over one doc: checks fieldnorm reads and the score
// of the only match.
pub fn test_term_scorer() {
let left_fieldnorms = U32FastFieldReader::from(vec!(10, 4));
assert_eq!(left_fieldnorms.get(0), 10);
assert_eq!(left_fieldnorms.get(1), 4);
let left = VecPostings::from(vec!(1));
let mut left_scorer = TermScorer {
idf: 0.30685282,
fieldnorm_reader: left_fieldnorms,
postings: left,
};
left_scorer.advance();
assert!(abs_diff(left_scorer.score(), 0.15342641) < 0.001f32);
}
}

View File

@@ -7,4 +7,73 @@ mod score_combiner;
pub use self::boolean_query::BooleanQuery;
pub use self::boolean_clause::BooleanClause;
pub use self::boolean_scorer::BooleanScorer;
pub use self::score_combiner::ScoreCombiner;
pub use self::score_combiner::ScoreCombiner;
#[cfg(test)]
mod tests {
use super::*;
use postings::{DocSet, VecPostings};
use query::Scorer;
use query::OccurFilter;
use query::term_query::TermScorer;
use query::Occur;
use fastfield::{U32FastFieldReader};
// Absolute difference, used for approximate float comparisons below.
fn abs_diff(left: f32, right: f32) -> f32 {
(right - left).abs()
}
#[test]
// Disjunction (two `Should` clauses) of two term scorers: the matched
// docs must be the union of {1, 2, 3} and {1, 3, 8}, in order.
pub fn test_boolean_scorer() {
let occurs = vec!(Occur::Should, Occur::Should);
let occur_filter = OccurFilter::new(&occurs);
let left_fieldnorms = U32FastFieldReader::from(vec!(100,200,300));
let left = VecPostings::from(vec!(1, 2, 3));
let left_scorer = TermScorer {
idf: 1f32,
fieldnorm_reader: left_fieldnorms,
postings: left,
};
let right_fieldnorms = U32FastFieldReader::from(vec!(15,25,35));
let right = VecPostings::from(vec!(1, 3, 8));
let right_scorer = TermScorer {
idf: 4f32,
fieldnorm_reader: right_fieldnorms,
postings: right,
};
let mut boolean_scorer = BooleanScorer::new(vec!(left_scorer, right_scorer), occur_filter);
// Expected scores are the values the scorer currently produces;
// NOTE(review): they depend on `TermScorer`'s scoring formula — confirm
// against it if they ever drift.
assert_eq!(boolean_scorer.next(), Some(1u32));
assert!(abs_diff(boolean_scorer.score(), 0.8707107) < 0.001);
assert_eq!(boolean_scorer.next(), Some(2u32));
assert!(abs_diff(boolean_scorer.score(), 0.028867513) < 0.001f32);
assert_eq!(boolean_scorer.next(), Some(3u32));
assert_eq!(boolean_scorer.next(), Some(8u32));
assert!(abs_diff(boolean_scorer.score(), 0.5163978) < 0.001f32);
assert!(!boolean_scorer.advance());
}
#[test]
// Single term scorer over one doc: checks fieldnorm reads and the score
// of the only match.
pub fn test_term_scorer() {
let left_fieldnorms = U32FastFieldReader::from(vec!(10, 4));
assert_eq!(left_fieldnorms.get(0), 10);
assert_eq!(left_fieldnorms.get(1), 4);
let left = VecPostings::from(vec!(1));
let mut left_scorer = TermScorer {
idf: 0.30685282,
fieldnorm_reader: left_fieldnorms,
postings: left,
};
left_scorer.advance();
assert!(abs_diff(left_scorer.score(), 0.15342641) < 0.001f32);
}
}

View File

@@ -8,6 +8,7 @@ use core::searcher::Searcher;
use query::occur::Occur;
use query::occur_filter::OccurFilter;
use query::term_query::TermQuery;
use postings::SegmentPostingsOption;
/// Query involving one or more terms.
@@ -36,7 +37,11 @@ impl MultiTermQuery {
.collect();
let occur_filter = OccurFilter::new(&occurs);
let weights = term_queries.iter()
.map(|term_query| term_query.specialized_weight(searcher))
.map(|term_query| {
let mut term_weight = term_query.specialized_weight(searcher);
term_weight.segment_postings_options = SegmentPostingsOption::FreqAndPositions;
term_weight
})
.collect();
MultiTermWeight {
weights: weights,

View File

@@ -4,4 +4,65 @@ mod phrase_scorer;
pub use self::phrase_query::PhraseQuery;
pub use self::phrase_weight::PhraseWeight;
pub use self::phrase_scorer::PhraseScorer;
pub use self::phrase_scorer::PhraseScorer;
#[cfg(test)]
mod tests {
use super::*;
use query::Query;
use core::Index;
use schema::{Document, Term, SchemaBuilder, TEXT};
use collector::tests::TestCollector;
#[test]
// Phrase query for "a b": every one of the four documents below contains
// "a b" as consecutive tokens, so all of them must match.
//
// Bug fix: three of the four documents had been commented out (debugging
// leftover) while the final assertion still expected docs 0..=3, so the
// test could not pass as written. The documents are restored.
pub fn test_phrase_query() {
    let mut schema_builder = SchemaBuilder::default();
    let text_field = schema_builder.add_text_field("text", TEXT);
    let schema = schema_builder.build();
    let index = Index::create_in_ram(schema);
    {
        let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap();
        {
            let mut doc = Document::default();
            doc.add_text(text_field, "a b b d c g c");
            index_writer.add_document(doc).unwrap();
        }
        {
            let mut doc = Document::default();
            doc.add_text(text_field, "a b a b c");
            index_writer.add_document(doc).unwrap();
        }
        {
            let mut doc = Document::default();
            doc.add_text(text_field, "c a b a d ga a");
            index_writer.add_document(doc).unwrap();
        }
        {
            let mut doc = Document::default();
            doc.add_text(text_field, "a b c");
            index_writer.add_document(doc).unwrap();
        }
        assert!(index_writer.commit().is_ok());
    }
    let mut test_collector = TestCollector::default();
    // Builds a phrase query out of a list of term strings.
    let build_query = |texts: Vec<&str>| {
        let terms: Vec<Term> = texts
            .iter()
            .map(|text| {
                Term::from_field_text(text_field, text)
            })
            .collect();
        PhraseQuery::from(terms)
    };
    let phrase_query = build_query(vec!("a", "b"));
    let searcher = index.searcher();
    phrase_query.search(&*searcher, &mut test_collector).expect("search should succeed");
    assert_eq!(test_collector.docs(), vec!(0, 1, 2, 3));
}
}

View File

@@ -33,8 +33,9 @@ impl Query for PhraseQuery {
}
impl PhraseQuery {
pub fn new(terms: Vec<Term>) -> PhraseQuery {
impl From<Vec<Term>> for PhraseQuery {
fn from(terms: Vec<Term>) -> PhraseQuery {
assert!(terms.len() > 1);
let occur_terms: Vec<(Occur, Term)> = terms.into_iter()
.map(|term| (Occur::Must, term))

View File

@@ -7,24 +7,51 @@ use postings::Postings;
use DocId;
pub struct PhraseScorer<'a> {
pub all_term_scorer: BooleanScorer<TermScorer<SegmentPostings<'a>>>
pub all_term_scorer: BooleanScorer<TermScorer<SegmentPostings<'a>>>,
pub positions_offsets: Vec<u32>,
}
impl<'a> PhraseScorer<'a> {
fn phrase_match(&self) -> bool {
let scorers = self.all_term_scorer.scorers();
for scorer in scorers {
let positions = scorer.postings().positions();
println!("phrase_match");
let mut positions_arr: Vec<&[u32]> = self.all_term_scorer
.scorers()
.iter()
.map(|scorer| {
println!("{:?}", scorer.doc());
scorer.postings().positions()
})
.collect();
println!("positions arr {:?}", positions_arr);
let mut cur = 0;
'outer: loop {
for i in 0..positions_arr.len() {
let positions: &mut &[u32] = &mut positions_arr[i];
println!("{} {:?} {:?}", i, positions, self.positions_offsets);
if positions.len() == 0 {
return false;
}
let head_position = positions[0] + self.positions_offsets[i];
println!("cur: {}, head_position {}", cur, head_position);
while head_position < cur {
if positions.len() == 1 {
return false;
}
*positions = &(*positions)[1..];
}
if head_position != cur {
cur = head_position;
continue 'outer;
}
}
return true;
}
true
// self.all_term_scorer.positions();
// let positions =
}
}
impl<'a> DocSet for PhraseScorer<'a> {
fn advance(&mut self,) -> bool {
println!("docset advance");
while self.all_term_scorer.advance() {
if self.phrase_match() {
return true;

View File

@@ -20,8 +20,10 @@ impl From<MultiTermWeight> for PhraseWeight {
impl Weight for PhraseWeight {
/// Builds a `PhraseScorer` over the given segment.
fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> Result<Box<Scorer + 'a>> {
    let all_term_scorer = try!(self.all_term_weight.specialized_scorer(reader));
    // One offset per phrase term: term i is expected at `position + i`.
    // NOTE(review): confirm the offset direction matches the alignment
    // performed in `PhraseScorer::phrase_match`.
    let positions_offsets: Vec<u32> = (0u32..all_term_scorer.num_subscorers() as u32).collect();
    Ok(box PhraseScorer {
        all_term_scorer: all_term_scorer,
        positions_offsets: positions_offsets,
    })
}
}

View File

@@ -3,6 +3,7 @@ use Result;
use super::term_weight::TermWeight;
use query::Query;
use query::Weight;
use postings::SegmentPostingsOption;
use Searcher;
use std::any::Any;
@@ -31,7 +32,8 @@ impl TermQuery {
TermWeight {
num_docs: searcher.num_docs(),
doc_freq: searcher.doc_freq(&self.term),
term: self.term.clone()
term: self.term.clone(),
segment_postings_options: SegmentPostingsOption::NoFreq,
}
}
}

View File

@@ -11,7 +11,8 @@ use Result;
pub struct TermWeight {
pub num_docs: u32,
pub doc_freq: u32,
pub term: Term,
pub term: Term,
pub segment_postings_options: SegmentPostingsOption,
}
@@ -35,7 +36,7 @@ impl TermWeight {
let fieldnorm_reader = try!(reader.get_fieldnorms_reader(field));
Ok(
reader
.read_postings(&self.term, SegmentPostingsOption::Freq)
.read_postings(&self.term, self.segment_postings_options)
.map(|segment_postings|
TermScorer {
idf: self.idf(),