diff --git a/src/postings/postings.rs b/src/postings/postings.rs index ca853efd5..f66c6434d 100644 --- a/src/postings/postings.rs +++ b/src/postings/postings.rs @@ -16,5 +16,5 @@ pub trait Postings: DocSet + 'static { /// Returns the list of positions of the term, expressed as a list of /// token ordinals. - fn positions_with_offset(&self, offset: u32, output: &mut Vec<u32>); + fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>); } diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs index 3577ed2b4..410d67e27 100644 --- a/src/postings/segment_postings.rs +++ b/src/postings/segment_postings.rs @@ -35,11 +35,9 @@ impl PositionComputer { } pub fn add_skip(&mut self, num_skip: usize) { - self.position_to_skip = Some( - self.position_to_skip - .map(|prev_skip| prev_skip + num_skip) - .unwrap_or(0), - ); + self.position_to_skip = self.position_to_skip + .map(|prev_skip| prev_skip + num_skip) + .or(Some(0)); } pub fn positions(&mut self, offset: u32, output: &mut [u32]) { @@ -68,7 +66,7 @@ pub struct SegmentPostings { block_cursor: BlockSegmentPostings, cur: usize, delete_bitset: TDeleteSet, - position_computer: Option<UnsafeCell<PositionComputer>>, + position_computer: Option<PositionComputer>, } impl SegmentPostings { @@ -111,14 +109,7 @@ impl SegmentPostings { } impl SegmentPostings { - fn position_add_skip usize>(&self, num_skips_fn: F) { - if let Some(position_computer) = self.position_computer.as_ref() { - let num_skips = num_skips_fn(); - unsafe { - (*position_computer.get()).add_skip(num_skips); - } - } - } + /// Reads a Segment postings from an &[u8] @@ -132,13 +123,11 @@ impl SegmentPostings { delete_bitset: TDeleteSet, positions_stream_opt: Option, ) -> SegmentPostings { - let position_computer = - positions_stream_opt.map(|stream| UnsafeCell::new(PositionComputer::new(stream))); SegmentPostings { block_cursor: segment_block_postings, cur: COMPRESSION_BLOCK_SIZE, // cursor within the block delete_bitset, - position_computer, + position_computer: 
positions_stream_opt.map(PositionComputer::new), } } } @@ -149,7 +138,12 @@ impl DocSet for SegmentPostings { #[inline] fn advance(&mut self) -> bool { loop { - self.position_add_skip(|| self.term_freq() as usize); + { + if self.position_computer.is_some() { + let term_freq = self.term_freq() as usize; + self.position_computer.as_mut().unwrap().add_skip(term_freq); + } + } self.cur += 1; if self.cur >= self.block_cursor.block_len() { self.cur = 0; @@ -164,6 +158,7 @@ impl DocSet for SegmentPostings { } } + fn skip_next(&mut self, target: DocId) -> SkipResult { if !self.advance() { return SkipResult::End; @@ -185,17 +180,16 @@ impl DocSet for SegmentPostings { // so that position_add_skip will decide if it should // just set itself to Some(0) or effectively // add the term freq. - //let num_skips: u32 = ; - self.position_add_skip(|| { + if self.position_computer.is_some() { let freqs_skipped = &self.block_cursor.freqs()[self.cur..]; - let sum_freq: u32 = freqs_skipped.iter().cloned().sum(); - sum_freq as usize - }); - + let sum_freq: u32 = freqs_skipped.iter().sum(); + self.position_computer.as_mut() + .unwrap() + .add_skip(sum_freq as usize); + } if !self.block_cursor.advance() { return SkipResult::End; } - self.cur = 0; } else { if target < current_doc { @@ -246,11 +240,13 @@ impl DocSet for SegmentPostings { // `doc` is now >= `target` let doc = block_docs[start]; - self.position_add_skip(|| { + if self.position_computer.is_some() { let freqs_skipped = &self.block_cursor.freqs()[self.cur..start]; let sum_freqs: u32 = freqs_skipped.iter().sum(); - sum_freqs as usize - }); + self.position_computer.as_mut() + .unwrap() + .add_skip(sum_freqs as usize); + } self.cur = start; @@ -312,8 +308,8 @@ impl Postings for SegmentPostings { self.block_cursor.freq(self.cur) } - fn positions_with_offset(&self, offset: u32, output: &mut Vec<u32>) { - if let Some(ref position_computer) = self.position_computer.as_ref() { + fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) 
{ + if self.position_computer.is_some() { let prev_capacity = output.capacity(); let term_freq = self.term_freq() as usize; if term_freq > prev_capacity { @@ -322,7 +318,7 @@ impl Postings for SegmentPostings { } unsafe { output.set_len(term_freq); - (&mut *position_computer.get()).positions(offset, &mut output[..]) + self.position_computer.as_mut().unwrap().positions(offset, &mut output[..]) } } else { unimplemented!("You may not read positions twice!"); @@ -608,3 +604,4 @@ mod tests { assert_eq!(block_segments.docs(), &[1, 3, 5]); } } + diff --git a/src/query/intersection.rs b/src/query/intersection.rs index 4a60fcb01..75598b237 100644 --- a/src/query/intersection.rs +++ b/src/query/intersection.rs @@ -86,12 +86,13 @@ impl Intersection { } } + impl Intersection { - pub fn docset(&self, ord: usize) -> &TDocSet { + pub fn docset_mut_specialized(&mut self, ord: usize) -> &mut TDocSet { match ord { - 0 => &self.left, - 1 => &self.right, - n => &self.others[n - 2] + 0 => &mut self.left, + 1 => &mut self.right, + n => &mut self.others[n - 2] } } } diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index 4accd1fb3..f77c63d68 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -18,7 +18,7 @@ impl PostingsWithOffset { } } - pub fn positions(&self, output: &mut Vec<u32>) { + pub fn positions(&mut self, output: &mut Vec<u32>) { self.postings.positions_with_offset(self.offset, output) } } @@ -86,11 +86,15 @@ impl PhraseScorer { } fn phrase_match(&mut self) -> bool { - // TODO early exit when we don't care about th phrase frequency - self.intersection_docset.docset(0).positions(&mut self.left); + // TODO early exit when we don't care about the phrase frequency + { + self.intersection_docset.docset_mut_specialized(0).positions(&mut self.left); + } let mut intersection_len = self.left.len(); for i in 1..self.num_docsets { - self.intersection_docset.docset(i).positions(&mut self.right); + { + 
self.intersection_docset.docset_mut_specialized(i).positions(&mut self.right); + } intersection_len = intersection_arr(&mut self.left[..intersection_len], &self.right[..]); if intersection_len == 0 { return false;