work in batches of docs (#1937)

* work in batches of docs

* add fill_buffer test
Author: PSeitz
Date: 2023-03-21 13:57:44 +08:00
Committed by: GitHub
parent 9e2faecf5b
commit 6a7a1106d6
15 changed files with 151 additions and 106 deletions
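The change in a nutshell: `Weight::for_each_no_score` now hands matching documents to its callback in fixed-size batches rather than one at a time, and `SegmentCollector` gains a `collect_block` method to consume them. The changed callback signature, extracted from the hunks below:

    // Before this commit: one callback invocation per matching document.
    fn for_each_no_score(
        &self,
        reader: &SegmentReader,
        callback: &mut dyn FnMut(DocId),
    ) -> crate::Result<()>;

    // After this commit: one invocation per batch of up to BUFFER_LEN (64) docs.
    fn for_each_no_score(
        &self,
        reader: &SegmentReader,
        callback: &mut dyn FnMut(&[DocId]),
    ) -> crate::Result<()>;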

View File

@@ -72,7 +72,7 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync {
         let cutoff = indexes.len() - indexes.len() % step_size;
         for idx in cutoff..indexes.len() {
-            output[idx] = self.get_val(indexes[idx] as u32);
+            output[idx] = self.get_val(indexes[idx]);
         }
     }

View File

@@ -53,7 +53,7 @@ use crate::TantivyError;
 /// into segment_size.
 ///
 /// Result type is [`BucketResult`](crate::aggregation::agg_result::BucketResult) with
-/// [`TermBucketEntry`](crate::aggregation::agg_result::BucketEntry) on the
+/// [`BucketEntry`](crate::aggregation::agg_result::BucketEntry) on the
 /// `AggregationCollector`.
 ///
 /// Result type is
@@ -209,45 +209,6 @@ struct TermBuckets {
     pub(crate) sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
 }
-
-#[derive(Clone, Default)]
-struct TermBucketEntry {
-    doc_count: u64,
-    sub_aggregations: Option<Box<dyn SegmentAggregationCollector>>,
-}
-
-impl Debug for TermBucketEntry {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("TermBucketEntry")
-            .field("doc_count", &self.doc_count)
-            .finish()
-    }
-}
-
-impl TermBucketEntry {
-    fn from_blueprint(blueprint: &Option<Box<dyn SegmentAggregationCollector>>) -> Self {
-        Self {
-            doc_count: 0,
-            sub_aggregations: blueprint.clone(),
-        }
-    }
-
-    pub(crate) fn into_intermediate_bucket_entry(
-        self,
-        agg_with_accessor: &AggregationsWithAccessor,
-    ) -> crate::Result<IntermediateTermBucketEntry> {
-        let sub_aggregation = if let Some(sub_aggregation) = self.sub_aggregations {
-            sub_aggregation.into_intermediate_aggregations_result(agg_with_accessor)?
-        } else {
-            Default::default()
-        };
-
-        Ok(IntermediateTermBucketEntry {
-            doc_count: self.doc_count,
-            sub_aggregation,
-        })
-    }
-}
 
 impl TermBuckets {
     fn force_flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) -> crate::Result<()> {
         for sub_aggregations in &mut self.sub_aggs.values_mut() {
@@ -314,7 +275,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
         if accessor.get_cardinality() == Cardinality::Full {
             self.val_cache.resize(docs.len(), 0);
             accessor.values.get_vals(docs, &mut self.val_cache);
-            for (doc, term_id) in docs.iter().zip(self.val_cache.iter().cloned()) {
+            for term_id in self.val_cache.iter().cloned() {
                 let entry = self.term_buckets.entries.entry(term_id).or_default();
                 *entry += 1;
             }
@@ -445,17 +406,19 @@ impl SegmentTermCollector {
         let mut into_intermediate_bucket_entry =
             |id, doc_count| -> crate::Result<IntermediateTermBucketEntry> {
-                let intermediate_entry = if let Some(blueprint) = self.blueprint.as_ref() {
+                let intermediate_entry = if self.blueprint.as_ref().is_some() {
                     IntermediateTermBucketEntry {
                         doc_count,
                         sub_aggregation: self
                             .term_buckets
                             .sub_aggs
                             .remove(&id)
-                            .expect(&format!(
-                                "Internal Error: could not find subaggregation for id {}",
-                                id
-                            ))
+                            .unwrap_or_else(|| {
+                                panic!(
+                                    "Internal Error: could not find subaggregation for id {}",
+                                    id
+                                )
+                            })
                             .into_intermediate_aggregations_result(
                                 &agg_with_accessor.sub_aggregation,
                             )?,
@@ -525,21 +488,11 @@ impl SegmentTermCollector {
 pub(crate) trait GetDocCount {
     fn doc_count(&self) -> u64;
 }
-impl GetDocCount for (u32, TermBucketEntry) {
-    fn doc_count(&self) -> u64 {
-        self.1.doc_count
-    }
-}
 impl GetDocCount for (u64, u64) {
     fn doc_count(&self) -> u64 {
         self.1
     }
 }
-impl GetDocCount for (u64, TermBucketEntry) {
-    fn doc_count(&self) -> u64 {
-        self.1.doc_count
-    }
-}
 impl GetDocCount for (String, IntermediateTermBucketEntry) {
     fn doc_count(&self) -> u64 {
         self.1.doc_count
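With `TermBucketEntry` removed, the full-cardinality fast path above counts occurrences per `term_id` directly in the `entries` map (`or_default()` then `*entry += 1`). A self-contained sketch of that counting idiom, using std's `HashMap` in place of the crate's `FxHashMap` (the map's `u64 -> u64` shape is inferred from the diff, not shown in it):

    use std::collections::HashMap;

    // Count how often each term id occurs in a batch of column values.
    fn count_term_ids(val_cache: &[u64]) -> HashMap<u64, u64> {
        let mut entries: HashMap<u64, u64> = HashMap::new();
        for term_id in val_cache.iter().cloned() {
            // or_default() inserts 0 for unseen term ids, then we increment.
            let entry = entries.entry(term_id).or_default();
            *entry += 1;
        }
        entries
    }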

View File

@@ -64,9 +64,8 @@ impl SegmentAggregationCollector for BufAggregationCollector {
         docs: &[crate::DocId],
         agg_with_accessor: &AggregationsWithAccessor,
     ) -> crate::Result<()> {
-        for doc in docs {
-            self.collect(*doc, agg_with_accessor)?;
-        }
+        self.collector.collect_block(docs, agg_with_accessor)?;
         Ok(())
     }

View File

@@ -8,7 +8,7 @@ use super::segment_agg_result::{
 };
 use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate;
 use crate::collector::{Collector, SegmentCollector};
-use crate::{SegmentReader, TantivyError};
+use crate::{DocId, SegmentReader, TantivyError};
 
 /// The default max bucket count, before the aggregation fails.
 pub const DEFAULT_BUCKET_LIMIT: u32 = 65000;
@@ -125,7 +125,7 @@ fn merge_fruits(
 /// `AggregationSegmentCollector` does the aggregation collection on a segment.
 pub struct AggregationSegmentCollector {
     aggs_with_accessor: AggregationsWithAccessor,
-    result: BufAggregationCollector,
+    agg_collector: BufAggregationCollector,
     error: Option<TantivyError>,
 }
@@ -142,7 +142,7 @@ impl AggregationSegmentCollector {
             BufAggregationCollector::new(build_segment_agg_collector(&aggs_with_accessor)?);
         Ok(AggregationSegmentCollector {
             aggs_with_accessor,
-            result,
+            agg_collector: result,
             error: None,
         })
     }
@@ -152,11 +152,26 @@ impl SegmentCollector for AggregationSegmentCollector {
     type Fruit = crate::Result<IntermediateAggregationResults>;
 
     #[inline]
-    fn collect(&mut self, doc: crate::DocId, _score: crate::Score) {
+    fn collect(&mut self, doc: DocId, _score: crate::Score) {
         if self.error.is_some() {
             return;
         }
-        if let Err(err) = self.result.collect(doc, &self.aggs_with_accessor) {
+        if let Err(err) = self.agg_collector.collect(doc, &self.aggs_with_accessor) {
+            self.error = Some(err);
+        }
+    }
+
+    /// The query pushes the documents to the collector via this method.
+    ///
+    /// Only valid for Collectors that ignore docs
+    fn collect_block(&mut self, docs: &[DocId]) {
+        if self.error.is_some() {
+            return;
+        }
+        if let Err(err) = self
+            .agg_collector
+            .collect_block(docs, &self.aggs_with_accessor)
+        {
             self.error = Some(err);
         }
     }
@@ -165,7 +180,7 @@ impl SegmentCollector for AggregationSegmentCollector {
         if let Some(err) = self.error {
             return Err(err);
         }
-        self.result.flush(&self.aggs_with_accessor)?;
-        Box::new(self.result).into_intermediate_aggregations_result(&self.aggs_with_accessor)
+        self.agg_collector.flush(&self.aggs_with_accessor)?;
+        Box::new(self.agg_collector).into_intermediate_aggregations_result(&self.aggs_with_accessor)
     }
 }

View File

@@ -180,9 +180,11 @@ pub trait Collector: Sync + Send {
                     })?;
                 }
                 (Some(alive_bitset), false) => {
-                    weight.for_each_no_score(reader, &mut |doc| {
-                        if alive_bitset.is_alive(doc) {
-                            segment_collector.collect(doc, 0.0);
+                    weight.for_each_no_score(reader, &mut |docs| {
+                        for doc in docs.iter().cloned() {
+                            if alive_bitset.is_alive(doc) {
+                                segment_collector.collect(doc, 0.0);
+                            }
                         }
                     })?;
                 }
@@ -192,8 +194,8 @@ pub trait Collector: Sync + Send {
                     })?;
                 }
                 (None, false) => {
-                    weight.for_each_no_score(reader, &mut |doc| {
-                        segment_collector.collect(doc, 0.0);
+                    weight.for_each_no_score(reader, &mut |docs| {
+                        segment_collector.collect_block(docs);
                     })?;
                 }
             }
@@ -270,6 +272,13 @@ pub trait SegmentCollector: 'static {
     /// The query pushes the scored document to the collector via this method.
     fn collect(&mut self, doc: DocId, score: Score);
 
+    /// The query pushes the scored document to the collector via this method.
+    fn collect_block(&mut self, docs: &[DocId]) {
+        for doc in docs {
+            self.collect(*doc, 0.0);
+        }
+    }
+
     /// Extract the fruit of the collection from the `SegmentCollector`.
     fn harvest(self) -> Self::Fruit;
 }
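Because `collect_block` ships with a default body, every existing `SegmentCollector` keeps compiling and simply receives each batch one document at a time; only collectors that can exploit whole batches (like the aggregation collector above) need to override it. A minimal sketch of a hypothetical collector that relies on the default (`CountSegmentCollector` is illustrative, not part of this commit, and assumes the usual `Fruit` bounds are satisfied):

    // Hypothetical segment collector: it implements only `collect`, so the
    // default `collect_block` above unrolls each batch into per-doc calls.
    struct CountSegmentCollector {
        count: usize,
    }

    impl SegmentCollector for CountSegmentCollector {
        type Fruit = usize;

        fn collect(&mut self, _doc: DocId, _score: Score) {
            self.count += 1;
        }

        fn harvest(self) -> Self::Fruit {
            self.count
        }
    }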

View File

@@ -9,6 +9,8 @@ use crate::DocId;
 /// to compare `[u32; 4]`.
 pub const TERMINATED: DocId = i32::MAX as u32;
 
+pub const BUFFER_LEN: usize = 64;
+
 /// Represents an iterable set of sorted doc ids.
 pub trait DocSet: Send {
     /// Goes to the next element.
@@ -59,7 +61,7 @@ pub trait DocSet: Send {
     /// This method is only here for specific high-performance
     /// use case where batching. The normal way to
     /// go through the `DocId`'s is to call `.advance()`.
-    fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
+    fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
         if self.doc() == TERMINATED {
             return 0;
         }
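The hunk only shows the head of the default `fill_buffer` body. Presumably it goes on to drain the set through `advance()` until the buffer is full or the set terminates, along these lines (a reconstruction consistent with the `AllScorer` slow path later in this commit, not quoted from the diff):

    fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
        if self.doc() == TERMINATED {
            return 0;
        }
        for (i, buffer_val) in buffer.iter_mut().enumerate() {
            *buffer_val = self.doc();
            if self.advance() == TERMINATED {
                // Set exhausted mid-buffer: i + 1 entries are valid.
                return i + 1;
            }
        }
        // Buffer completely filled; the set may still have more docs.
        buffer.len()
    }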
@@ -149,6 +151,11 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
         unboxed.seek(target)
     }
 
+    fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
+        let unboxed: &mut TDocSet = self.borrow_mut();
+        unboxed.fill_buffer(buffer)
+    }
+
     fn doc(&self) -> DocId {
         let unboxed: &TDocSet = self.borrow();
         unboxed.doc()

View File

@@ -94,10 +94,12 @@ fn compute_deleted_bitset(
             // document that were inserted before it.
             delete_op
                 .target
-                .for_each_no_score(segment_reader, &mut |doc_matching_delete_query| {
-                    if doc_opstamps.is_deleted(doc_matching_delete_query, delete_op.opstamp) {
-                        alive_bitset.remove(doc_matching_delete_query);
-                        might_have_changed = true;
+                .for_each_no_score(segment_reader, &mut |docs_matching_delete_query| {
+                    for doc_matching_delete_query in docs_matching_delete_query.iter().cloned() {
+                        if doc_opstamps.is_deleted(doc_matching_delete_query, delete_op.opstamp) {
+                            alive_bitset.remove(doc_matching_delete_query);
+                            might_have_changed = true;
+                        }
                     }
                 })?;
             delete_cursor.advance();

View File

@@ -1,5 +1,5 @@
 use crate::core::SegmentReader;
-use crate::docset::{DocSet, TERMINATED};
+use crate::docset::{DocSet, BUFFER_LEN, TERMINATED};
 use crate::query::boost_query::BoostScorer;
 use crate::query::explanation::does_not_match;
 use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight};
@@ -44,6 +44,7 @@ pub struct AllScorer {
 }
 
 impl DocSet for AllScorer {
+    #[inline(always)]
     fn advance(&mut self) -> DocId {
         if self.doc + 1 >= self.max_doc {
             self.doc = TERMINATED;
@@ -53,6 +54,30 @@ impl DocSet for AllScorer {
         self.doc
     }
 
+    fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
+        if self.doc() == TERMINATED {
+            return 0;
+        }
+        let is_safe_distance = self.doc() + (buffer.len() as u32) < self.max_doc;
+        if is_safe_distance {
+            let num_items = buffer.len();
+            for buffer_val in buffer {
+                *buffer_val = self.doc();
+                self.doc += 1;
+            }
+            num_items
+        } else {
+            for (i, buffer_val) in buffer.iter_mut().enumerate() {
+                *buffer_val = self.doc();
+                if self.advance() == TERMINATED {
+                    return i + 1;
+                }
+            }
+            buffer.len()
+        }
+    }
+
+    #[inline(always)]
     fn doc(&self) -> DocId {
         self.doc
     }
@@ -71,8 +96,8 @@ impl Scorer for AllScorer {
 #[cfg(test)]
 mod tests {
     use super::AllQuery;
-    use crate::docset::TERMINATED;
-    use crate::query::{EnableScoring, Query};
+    use crate::docset::{DocSet, BUFFER_LEN, TERMINATED};
+    use crate::query::{AllScorer, EnableScoring, Query};
     use crate::schema::{Schema, TEXT};
     use crate::Index;
@@ -132,4 +157,22 @@ mod tests {
         }
         Ok(())
     }
+
+    #[test]
+    pub fn test_fill_buffer() {
+        let mut postings = AllScorer {
+            doc: 0u32,
+            max_doc: BUFFER_LEN as u32 * 2 + 9,
+        };
+        let mut buffer = [0u32; BUFFER_LEN];
+        assert_eq!(postings.fill_buffer(&mut buffer), BUFFER_LEN);
+        for i in 0u32..BUFFER_LEN as u32 {
+            assert_eq!(buffer[i as usize], i);
+        }
+        assert_eq!(postings.fill_buffer(&mut buffer), BUFFER_LEN);
+        for i in 0u32..BUFFER_LEN as u32 {
+            assert_eq!(buffer[i as usize], i + BUFFER_LEN as u32);
+        }
+        assert_eq!(postings.fill_buffer(&mut buffer), 9);
+    }
 }

View File

@@ -45,6 +45,7 @@ impl From<BitSet> for BitSetDocSet {
 }
 
 impl DocSet for BitSetDocSet {
+    #[inline]
     fn advance(&mut self) -> DocId {
         if let Some(lower) = self.cursor_tinybitset.pop_lowest() {
             self.doc = (self.cursor_bucket * 64u32) | lower;

View File

@@ -1,11 +1,12 @@
 use std::collections::HashMap;
 
 use crate::core::SegmentReader;
+use crate::docset::BUFFER_LEN;
 use crate::postings::FreqReadingOption;
 use crate::query::explanation::does_not_match;
 use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
 use crate::query::term_query::TermScorer;
-use crate::query::weight::{for_each_docset, for_each_pruning_scorer, for_each_scorer};
+use crate::query::weight::{for_each_docset_buffered, for_each_pruning_scorer, for_each_scorer};
 use crate::query::{
     intersect_scorers, EmptyScorer, Exclude, Explanation, Occur, RequiredOptionalScorer, Scorer,
     Union, Weight,
@@ -222,16 +223,18 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
     fn for_each_no_score(
         &self,
         reader: &SegmentReader,
-        callback: &mut dyn FnMut(DocId),
+        callback: &mut dyn FnMut(&[DocId]),
     ) -> crate::Result<()> {
         let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?;
+        let mut buffer = [0u32; BUFFER_LEN];
+
         match scorer {
             SpecializedScorer::TermUnion(term_scorers) => {
                 let mut union_scorer = Union::build(term_scorers, &self.score_combiner_fn);
-                for_each_docset(&mut union_scorer, callback);
+                for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
             }
             SpecializedScorer::Other(mut scorer) => {
-                for_each_docset(scorer.as_mut(), callback);
+                for_each_docset_buffered(scorer.as_mut(), &mut buffer, callback);
             }
         }
         Ok(())

View File

@@ -1,5 +1,6 @@
 use std::fmt;
 
+use crate::docset::BUFFER_LEN;
 use crate::fastfield::AliveBitSet;
 use crate::query::explanation::does_not_match;
 use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight};
@@ -106,7 +107,7 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
         self.underlying.seek(target)
     }
 
-    fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
+    fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
         self.underlying.fill_buffer(buffer)
     }

View File

@@ -1,5 +1,6 @@
 use std::fmt;
 
+use crate::docset::BUFFER_LEN;
 use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight};
 use crate::{DocId, DocSet, Score, SegmentReader, TantivyError, Term};
@@ -119,7 +120,7 @@ impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
         self.docset.seek(target)
     }
 
-    fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
+    fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
         self.docset.fill_buffer(buffer)
     }

View File

@@ -1,11 +1,11 @@
 use super::term_scorer::TermScorer;
 use crate::core::SegmentReader;
-use crate::docset::DocSet;
+use crate::docset::{DocSet, BUFFER_LEN};
 use crate::fieldnorm::FieldNormReader;
 use crate::postings::SegmentPostings;
 use crate::query::bm25::Bm25Weight;
 use crate::query::explanation::does_not_match;
-use crate::query::weight::{for_each_docset, for_each_scorer};
+use crate::query::weight::{for_each_docset_buffered, for_each_scorer};
 use crate::query::{Explanation, Scorer, Weight};
 use crate::schema::IndexRecordOption;
 use crate::{DocId, Score, Term};
@@ -61,10 +61,11 @@ impl Weight for TermWeight {
     fn for_each_no_score(
         &self,
         reader: &SegmentReader,
-        callback: &mut dyn FnMut(DocId),
+        callback: &mut dyn FnMut(&[DocId]),
     ) -> crate::Result<()> {
         let mut scorer = self.specialized_scorer(reader, 1.0)?;
-        for_each_docset(&mut scorer, callback);
+        let mut buffer = [0u32; BUFFER_LEN];
+        for_each_docset_buffered(&mut scorer, &mut buffer, callback);
         Ok(())
     }

View File

@@ -53,7 +53,7 @@ impl HasLen for VecDocSet {
 pub mod tests {
     use super::*;
-    use crate::docset::DocSet;
+    use crate::docset::{DocSet, BUFFER_LEN};
     use crate::DocId;
 
     #[test]
@@ -72,17 +72,17 @@ pub mod tests {
     #[test]
     pub fn test_fill_buffer() {
-        let doc_ids: Vec<DocId> = (1u32..210u32).collect();
+        let doc_ids: Vec<DocId> = (1u32..=(BUFFER_LEN as u32 * 2 + 9)).collect();
         let mut postings = VecDocSet::from(doc_ids);
-        let mut buffer = vec![1000u32; 100];
-        assert_eq!(postings.fill_buffer(&mut buffer[..]), 100);
-        for i in 0u32..100u32 {
+        let mut buffer = [0u32; BUFFER_LEN];
+        assert_eq!(postings.fill_buffer(&mut buffer), BUFFER_LEN);
+        for i in 0u32..BUFFER_LEN as u32 {
             assert_eq!(buffer[i as usize], i + 1);
         }
-        assert_eq!(postings.fill_buffer(&mut buffer[..]), 100);
-        for i in 0u32..100u32 {
-            assert_eq!(buffer[i as usize], i + 101);
+        assert_eq!(postings.fill_buffer(&mut buffer), BUFFER_LEN);
+        for i in 0u32..BUFFER_LEN as u32 {
+            assert_eq!(buffer[i as usize], i + 1 + BUFFER_LEN as u32);
         }
-        assert_eq!(postings.fill_buffer(&mut buffer[..]), 9);
+        assert_eq!(postings.fill_buffer(&mut buffer), 9);
     }
 }

View File

@@ -1,5 +1,6 @@
 use super::Scorer;
 use crate::core::SegmentReader;
+use crate::docset::BUFFER_LEN;
 use crate::query::Explanation;
 use crate::{DocId, DocSet, Score, TERMINATED};
@@ -18,11 +19,18 @@ pub(crate) fn for_each_scorer<TScorer: Scorer + ?Sized>(
 /// Iterates through all of the documents matched by the DocSet
 /// `DocSet`.
-pub(crate) fn for_each_docset<T: DocSet + ?Sized>(docset: &mut T, callback: &mut dyn FnMut(DocId)) {
-    let mut doc = docset.doc();
-    while doc != TERMINATED {
-        callback(doc);
-        doc = docset.advance();
-    }
-}
+#[inline]
+pub(crate) fn for_each_docset_buffered<T: DocSet + ?Sized>(
+    docset: &mut T,
+    buffer: &mut [DocId; BUFFER_LEN],
+    mut callback: impl FnMut(&[DocId]),
+) {
+    loop {
+        let num_items = docset.fill_buffer(buffer);
+        callback(&buffer[..num_items]);
+        if num_items != buffer.len() {
+            break;
+        }
+    }
+}
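The loop stops on the first partial batch, since `fill_buffer` only returns fewer than `BUFFER_LEN` items once the set is exhausted, so callers need no separate `TERMINATED` check. A crate-internal usage sketch (the `VecDocSet` construction mirrors the test elsewhere in this commit; draining into a `Vec` is illustrative):

    let mut docset = VecDocSet::from((0u32..150u32).collect::<Vec<DocId>>());
    let mut buffer = [0u32; BUFFER_LEN];
    let mut all_docs: Vec<DocId> = Vec::new();
    for_each_docset_buffered(&mut docset, &mut buffer, |docs| {
        // For 150 docs and BUFFER_LEN = 64, this sees batches of 64, 64, 22.
        all_docs.extend_from_slice(docs);
    });
    assert_eq!(all_docs.len(), 150);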
@@ -93,10 +101,12 @@ pub trait Weight: Send + Sync + 'static {
     fn for_each_no_score(
         &self,
         reader: &SegmentReader,
-        callback: &mut dyn FnMut(DocId),
+        callback: &mut dyn FnMut(&[DocId]),
     ) -> crate::Result<()> {
         let mut docset = self.scorer(reader, 1.0)?;
-        for_each_docset(docset.as_mut(), callback);
+
+        let mut buffer = [0u32; BUFFER_LEN];
+        for_each_docset_buffered(&mut docset, &mut buffer, callback);
         Ok(())
     }
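For callers of `for_each_no_score`, migration follows the same pattern as the `compute_deleted_bitset` hunk above: the old per-document closure body moves inside a loop over the slice. A sketch (`weight` and `visit` are illustrative placeholders, not names from this commit):

    // Before: the callback received a single DocId.
    //     weight.for_each_no_score(reader, &mut |doc| visit(doc))?;

    // After: the callback receives a batch; iterate to recover per-doc handling.
    weight.for_each_no_score(reader, &mut |docs: &[DocId]| {
        for doc in docs.iter().cloned() {
            visit(doc);
        }
    })?;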