split term collection count and sub_agg (#1921)

use unrolled ColumnValues::get_vals
Author: PSeitz (committed by GitHub)
Date: 2023-03-13 11:37:41 +08:00
Parent: 61cfd8dc57
Commit: 8459efa32c

9 changed files with 124 additions and 44 deletions
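In short: `TermBuckets` used to hold one `TermBucketEntry` per term, carrying both the doc count and an optional boxed sub-aggregation collector. The commit splits that into a plain count map (`FxHashMap<u64, u64>`) and a separate `sub_aggs` map that is only touched when sub-aggregations are actually requested. It also batches column lookups through a reusable `val_cache` filled by a new loop-unrolled default `ColumnValues::get_vals`, and adds `#[inline]` hints to several hot `SegmentAggregationCollector` methods.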

.gitignore (2 additions)

@@ -13,3 +13,5 @@ benchmark
 .idea
 trace.dat
 cargo-timing*
+control
+variable


@@ -58,10 +58,21 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync {
     /// # Panics
     ///
     /// May panic if `idx` is greater than the column length.
-    fn get_vals(&self, idxs: &[u32], output: &mut [T]) {
-        assert!(idxs.len() == output.len());
-        for (out, &idx) in output.iter_mut().zip(idxs.iter()) {
-            *out = self.get_val(idx);
+    fn get_vals(&self, indexes: &[u32], output: &mut [T]) {
+        assert!(indexes.len() == output.len());
+        let out_and_idx_chunks = output.chunks_exact_mut(4).zip(indexes.chunks_exact(4));
+        for (out_x4, idx_x4) in out_and_idx_chunks {
+            out_x4[0] = self.get_val(idx_x4[0]);
+            out_x4[1] = self.get_val(idx_x4[1]);
+            out_x4[2] = self.get_val(idx_x4[2]);
+            out_x4[3] = self.get_val(idx_x4[3]);
+        }
+
+        let step_size = 4;
+        let cutoff = indexes.len() - indexes.len() % step_size;
+
+        for idx in cutoff..indexes.len() {
+            output[idx] = self.get_val(indexes[idx] as u32);
         }
     }
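The pattern is easy to study in isolation. Below is a minimal, self-contained sketch of the same chunks-of-4 unrolling (not part of the commit; `lookup` stands in for a concrete `get_val`):

    // Standalone sketch of the unrolling above, illustration only.
    fn get_vals_unrolled(lookup: impl Fn(u32) -> u64, indexes: &[u32], output: &mut [u64]) {
        assert!(indexes.len() == output.len());
        // Chunks of 4 give the optimizer four independent lookups per iteration.
        for (out_x4, idx_x4) in output.chunks_exact_mut(4).zip(indexes.chunks_exact(4)) {
            out_x4[0] = lookup(idx_x4[0]);
            out_x4[1] = lookup(idx_x4[1]);
            out_x4[2] = lookup(idx_x4[2]);
            out_x4[3] = lookup(idx_x4[3]);
        }
        // Scalar tail for the len % 4 leftover entries.
        let cutoff = indexes.len() - indexes.len() % 4;
        for i in cutoff..indexes.len() {
            output[i] = lookup(indexes[i]);
        }
    }

    fn main() {
        let column = [10u64, 20, 30, 40, 50, 60, 70];
        let indexes = [6u32, 0, 3, 1, 5];
        let mut output = [0u64; 5];
        get_vals_unrolled(|idx| column[idx as usize], &indexes, &mut output);
        assert_eq!(output, [70, 10, 40, 20, 60]);
    }

Spelling out the four lookups replaces one loop-carried step per element with four independent loads per iteration, which is presumably where the win comes from; `chunks_exact` guarantees the tail loop handles exactly `len % 4` leftovers.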


@@ -50,7 +50,7 @@ where
     Input: PartialOrd + Send + Debug + Sync + Clone,
     Output: PartialOrd + Send + Debug + Sync + Clone,
 {
-    #[inline]
+    #[inline(always)]
     fn get_val(&self, idx: u32) -> Output {
         let from_val = self.from_column.get_val(idx);
         self.monotonic_mapping.mapping(from_val)
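A plausible reading of this one-line change: the unrolled default `get_vals` funnels every lookup through `get_val`, and the monotonic-mapping column wraps another column, so a `get_val` that fails to inline here would add a function call per value and erase the gains from the unrolling. `#[inline(always)]` makes the hint unconditional.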


@@ -99,14 +99,28 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
     let reader = TColumnCodec::load(OwnedBytes::new(buffer)).unwrap();
     assert_eq!(reader.num_vals(), vals.len() as u32);
+    let mut buffer = Vec::new();
     for (doc, orig_val) in vals.iter().copied().enumerate() {
         let val = reader.get_val(doc as u32);
         assert_eq!(
             val, orig_val,
             "val `{val}` does not match orig_val {orig_val:?}, in data set {name}, data `{vals:?}`",
         );
+
+        buffer.resize(1, 0);
+        reader.get_vals(&[doc as u32], &mut buffer);
+        let val = buffer[0];
+        assert_eq!(
+            val, orig_val,
+            "val `{val}` does not match orig_val {orig_val:?}, in data set {name}, data `{vals:?}`",
+        );
     }
+
+    let all_docs: Vec<u32> = (0..vals.len() as u32).collect();
+    buffer.resize(all_docs.len(), 0);
+    reader.get_vals(&all_docs, &mut buffer);
+    assert_eq!(vals, buffer);
+
     if !vals.is_empty() {
         let test_rand_idx = rand::thread_rng().gen_range(0..=vals.len() - 1);
         let expected_positions: Vec<u32> = vals
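Note the coverage this buys: the one-element `get_vals` call can only exercise the scalar tail of the new default implementation, while the whole-column call also hits the chunks-of-4 fast path for data sets of four or more values.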


@@ -230,6 +230,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
         })
     }
 
+    #[inline]
     fn collect(
         &mut self,
         doc: crate::DocId,
@@ -238,6 +239,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
         self.collect_block(&[doc], agg_with_accessor)
     }
 
+    #[inline]
     fn collect_block(
         &mut self,
         docs: &[crate::DocId],


@@ -208,6 +208,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
         })
     }
 
+    #[inline]
     fn collect(
         &mut self,
         doc: crate::DocId,
@@ -216,6 +217,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
         self.collect_block(&[doc], agg_with_accessor)
     }
 
+    #[inline]
     fn collect_block(
         &mut self,
         docs: &[crate::DocId],


@@ -205,7 +205,8 @@ impl TermsAggregationInternal {
 #[derive(Clone, Debug, Default)]
 /// Container to store term_ids/or u64 values and their buckets.
 struct TermBuckets {
-    pub(crate) entries: FxHashMap<u64, TermBucketEntry>,
+    pub(crate) entries: FxHashMap<u64, u64>,
+    pub(crate) sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
 }
 
 #[derive(Clone, Default)]
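This is the heart of the commit: the hot counting path now bumps a plain `u64` instead of materializing a `TermBucketEntry` with an optional boxed collector. A toy sketch of the split-map idea (illustration only; std `HashMap` and a `Vec<u32>` stand in for `FxHashMap` and the boxed `SegmentAggregationCollector`):

    use std::collections::HashMap;

    // Counts live in a dense u64 map; per-term sub-collectors in a second
    // map that is only touched when sub-aggregations are requested.
    #[derive(Default)]
    struct Buckets {
        counts: HashMap<u64, u64>,
        subs: HashMap<u64, Vec<u32>>, // stand-in for a boxed collector
    }

    impl Buckets {
        fn collect(&mut self, term_id: u64, doc: u32, with_subs: bool) {
            // Hot path: bump a plain integer, no boxed collector allocated.
            *self.counts.entry(term_id).or_default() += 1;
            if with_subs {
                self.subs.entry(term_id).or_default().push(doc);
            }
        }
    }

    fn main() {
        let mut buckets = Buckets::default();
        for (doc, term) in [(0u32, 7u64), (1, 7), (2, 3)] {
            buckets.collect(term, doc, false);
        }
        assert_eq!(buckets.counts[&7], 2);
        assert!(buckets.subs.is_empty());
    }

When no sub-aggregation is configured, the second map is never touched, which keeps the common bare-terms-aggregation path small.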
@@ -249,10 +250,8 @@ impl TermBucketEntry {
 impl TermBuckets {
     fn force_flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) -> crate::Result<()> {
-        for entry in &mut self.entries.values_mut() {
-            if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
-                sub_aggregations.flush(agg_with_accessor)?;
-            }
+        for sub_aggregations in &mut self.sub_aggs.values_mut() {
+            sub_aggregations.as_mut().flush(agg_with_accessor)?;
         }
         Ok(())
     }
@@ -268,6 +267,7 @@ pub struct SegmentTermCollector {
     blueprint: Option<Box<dyn SegmentAggregationCollector>>,
     field_type: ColumnType,
     accessor_idx: usize,
+    val_cache: Vec<u64>,
 }
 
 pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
@@ -292,6 +292,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
         })
     }
 
+    #[inline]
     fn collect(
         &mut self,
         doc: crate::DocId,
@@ -300,6 +301,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
         self.collect_block(&[doc], agg_with_accessor)
     }
 
+    #[inline]
     fn collect_block(
         &mut self,
         docs: &[crate::DocId],
@@ -310,28 +312,35 @@ impl SegmentAggregationCollector for SegmentTermCollector {
             &agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation;
 
         if accessor.get_cardinality() == Cardinality::Full {
-            for doc in docs {
-                let term_id = accessor.values.get_val(*doc);
-                let entry = self
-                    .term_buckets
-                    .entries
-                    .entry(term_id)
-                    .or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
-                entry.doc_count += 1;
-                if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
+            self.val_cache.resize(docs.len(), 0);
+            accessor.values.get_vals(docs, &mut self.val_cache);
+            for term_id in self.val_cache.iter().cloned() {
+                let entry = self.term_buckets.entries.entry(term_id).or_default();
+                *entry += 1;
+            }
+            // has subagg
+            if let Some(blueprint) = self.blueprint.as_ref() {
+                for (doc, term_id) in docs.iter().zip(self.val_cache.iter().cloned()) {
+                    let sub_aggregations = self
+                        .term_buckets
+                        .sub_aggs
+                        .entry(term_id)
+                        .or_insert_with(|| blueprint.clone());
                     sub_aggregations.collect(*doc, sub_aggregation_accessor)?;
                 }
             }
         } else {
             for doc in docs {
                 for term_id in accessor.values_for_doc(*doc) {
-                    let entry = self
-                        .term_buckets
-                        .entries
-                        .entry(term_id)
-                        .or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
-                    entry.doc_count += 1;
-                    if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
+                    let entry = self.term_buckets.entries.entry(term_id).or_default();
+                    *entry += 1;
+                    // TODO: check if a separate loop is faster (may depend on the codec)
+                    if let Some(blueprint) = self.blueprint.as_ref() {
+                        let sub_aggregations = self
+                            .term_buckets
+                            .sub_aggs
+                            .entry(term_id)
+                            .or_insert_with(|| blueprint.clone());
                         sub_aggregations.collect(*doc, sub_aggregation_accessor)?;
                     }
                 }
             }
@@ -386,15 +395,16 @@ impl SegmentTermCollector {
             blueprint,
             field_type,
             accessor_idx,
+            val_cache: Default::default(),
         })
     }
 
+    #[inline]
     pub(crate) fn into_intermediate_bucket_result(
-        self,
+        mut self,
         agg_with_accessor: &BucketAggregationWithAccessor,
     ) -> crate::Result<IntermediateBucketResult> {
-        let mut entries: Vec<(u64, TermBucketEntry)> =
-            self.term_buckets.entries.into_iter().collect();
+        let mut entries: Vec<(u64, u64)> = self.term_buckets.entries.into_iter().collect();
 
         let order_by_sub_aggregation =
             matches!(self.req.order.target, OrderTarget::SubAggregation(_));
@@ -417,9 +427,9 @@ impl SegmentTermCollector {
             }
             OrderTarget::Count => {
                 if self.req.order.order == Order::Desc {
-                    entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.doc_count()));
+                    entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.1));
                 } else {
-                    entries.sort_unstable_by_key(|bucket| bucket.doc_count());
+                    entries.sort_unstable_by_key(|bucket| bucket.1);
                 }
             }
         }
@@ -432,6 +442,33 @@ impl SegmentTermCollector {
         let mut dict: FxHashMap<Key, IntermediateTermBucketEntry> = Default::default();
         dict.reserve(entries.len());
 
+        let mut into_intermediate_bucket_entry =
+            |id, doc_count| -> crate::Result<IntermediateTermBucketEntry> {
+                let intermediate_entry = if self.blueprint.is_some() {
+                    IntermediateTermBucketEntry {
+                        doc_count,
+                        sub_aggregation: self
+                            .term_buckets
+                            .sub_aggs
+                            .remove(&id)
+                            .unwrap_or_else(|| {
+                                panic!("Internal Error: could not find subaggregation for id {id}")
+                            })
+                            .into_intermediate_aggregations_result(
+                                &agg_with_accessor.sub_aggregation,
+                            )?,
+                    }
+                } else {
+                    IntermediateTermBucketEntry {
+                        doc_count,
+                        sub_aggregation: Default::default(),
+                    }
+                };
+                Ok(intermediate_entry)
+            };
+
         if self.field_type == ColumnType::Str {
             let term_dict = agg_with_accessor
                 .str_dict_column
@@ -439,17 +476,17 @@ impl SegmentTermCollector {
                 .expect("internal error: term dictionary not found for term aggregation");
 
             let mut buffer = String::new();
-            for (term_id, entry) in entries {
+            for (term_id, doc_count) in entries {
                 if !term_dict.ord_to_str(term_id, &mut buffer)? {
                     return Err(TantivyError::InternalError(format!(
                         "Couldn't find term_id {} in dict",
                         term_id
                     )));
                 }
-                dict.insert(
-                    Key::Str(buffer.to_string()),
-                    entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
-                );
+
+                let intermediate_entry = into_intermediate_bucket_entry(term_id, doc_count)?;
+                dict.insert(Key::Str(buffer.to_string()), intermediate_entry);
             }
             if self.req.min_doc_count == 0 {
                 // TODO: Handle rev streaming for descending sorting by keys
@@ -468,12 +505,10 @@ impl SegmentTermCollector {
                 }
             }
         } else {
-            for (val, entry) in entries {
+            for (val, doc_count) in entries {
+                let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
                 let val = f64_from_fastfield_u64(val, &self.field_type);
-                dict.insert(
-                    Key::F64(val),
-                    entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
-                );
+                dict.insert(Key::F64(val), intermediate_entry);
             }
         };
@@ -495,6 +530,11 @@ impl GetDocCount for (u32, TermBucketEntry) {
         self.1.doc_count
     }
 }
+impl GetDocCount for (u64, u64) {
+    fn doc_count(&self) -> u64 {
+        self.1
+    }
+}
 impl GetDocCount for (u64, TermBucketEntry) {
     fn doc_count(&self) -> u64 {
         self.1.doc_count


@@ -34,6 +34,7 @@ impl BufAggregationCollector {
 }
 
 impl SegmentAggregationCollector for BufAggregationCollector {
+    #[inline]
     fn into_intermediate_aggregations_result(
         self: Box<Self>,
         agg_with_accessor: &AggregationsWithAccessor,
@@ -41,6 +42,7 @@ impl SegmentAggregationCollector for BufAggregationCollector {
         Box::new(self.collector).into_intermediate_aggregations_result(agg_with_accessor)
     }
 
+    #[inline]
     fn collect(
         &mut self,
         doc: crate::DocId,
@@ -56,6 +58,7 @@ impl SegmentAggregationCollector for BufAggregationCollector {
         Ok(())
     }
 
+    #[inline]
     fn collect_block(
         &mut self,
         docs: &[crate::DocId],
@@ -67,6 +70,7 @@ impl SegmentAggregationCollector for BufAggregationCollector {
         Ok(())
     }
 
+    #[inline]
     fn flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) -> crate::Result<()> {
         self.collector
             .collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;


@@ -156,6 +156,7 @@ pub(crate) struct SegmentStatsCollector {
     pub(crate) collecting_for: SegmentStatsType,
     pub(crate) stats: IntermediateStats,
     pub(crate) accessor_idx: usize,
+    val_cache: Vec<u64>,
 }
 
 impl SegmentStatsCollector {
@@ -169,14 +170,16 @@ impl SegmentStatsCollector {
             collecting_for,
             stats: IntermediateStats::default(),
             accessor_idx,
+            val_cache: Default::default(),
         }
     }
 
     #[inline]
     pub(crate) fn collect_block_with_field(&mut self, docs: &[DocId], field: &Column<u64>) {
         if field.get_cardinality() == Cardinality::Full {
-            for doc in docs {
-                let val = field.values.get_val(*doc);
-                let val1 = f64_from_fastfield_u64(val, &self.field_type);
+            self.val_cache.resize(docs.len(), 0);
+            field.values.get_vals(docs, &mut self.val_cache);
+            for val in self.val_cache.iter() {
+                let val1 = f64_from_fastfield_u64(*val, &self.field_type);
                 self.stats.collect(val1);
             }
         } else {
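The stats collector adopts the same batching shape as the terms collector: gather all values for a block of docs into a reusable scratch buffer with one `get_vals` call, then fold over the buffer. A generic sketch of that gather-then-fold pattern (illustrative; `read_batch` and `Stats` are stand-ins, not tantivy APIs):

    // `read_batch` stands in for ColumnValues::get_vals,
    // `Stats` for IntermediateStats.
    #[derive(Default, Debug)]
    struct Stats {
        count: u64,
        sum: f64,
    }

    impl Stats {
        fn collect(&mut self, val: f64) {
            self.count += 1;
            self.sum += val;
        }
    }

    fn collect_block(
        read_batch: impl Fn(&[u32], &mut [u64]),
        docs: &[u32],
        scratch: &mut Vec<u64>, // reused across blocks to avoid re-allocating
        stats: &mut Stats,
    ) {
        scratch.resize(docs.len(), 0);
        read_batch(docs, scratch); // one batched column read instead of N calls
        for &val in scratch.iter() {
            stats.collect(val as f64);
        }
    }

    fn main() {
        let column = [3u64, 1, 4, 1, 5, 9, 2, 6];
        let read = |docs: &[u32], out: &mut [u64]| {
            for (o, &d) in out.iter_mut().zip(docs) {
                *o = column[d as usize];
            }
        };
        let (mut scratch, mut stats) = (Vec::new(), Stats::default());
        collect_block(read, &[0, 2, 4], &mut scratch, &mut stats);
        assert_eq!((stats.count, stats.sum), (3, 12.0));
    }

Keeping the scratch buffer in the collector struct (`val_cache`) is what lets the per-block resize stay allocation-free after the first block.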
@@ -191,6 +194,7 @@ impl SegmentStatsCollector {
 }
 
 impl SegmentAggregationCollector for SegmentStatsCollector {
+    #[inline]
     fn into_intermediate_aggregations_result(
         self: Box<Self>,
         agg_with_accessor: &AggregationsWithAccessor,
@@ -227,6 +231,7 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
         })
     }
 
+    #[inline]
     fn collect(
         &mut self,
         doc: crate::DocId,