mirror of https://github.com/quickwit-oss/tantivy.git
split term collection count and sub_agg (#1921)
use unrolled ColumnValues::get_vals
--- a/.gitignore
+++ b/.gitignore
@@ -13,3 +13,5 @@ benchmark
 .idea
 trace.dat
 cargo-timing*
+control
+variable
@@ -58,10 +58,21 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync {
     /// # Panics
     ///
     /// May panic if `idx` is greater than the column length.
-    fn get_vals(&self, idxs: &[u32], output: &mut [T]) {
-        assert!(idxs.len() == output.len());
-        for (out, &idx) in output.iter_mut().zip(idxs.iter()) {
-            *out = self.get_val(idx);
-        }
-    }
+    fn get_vals(&self, indexes: &[u32], output: &mut [T]) {
+        assert!(indexes.len() == output.len());
+        let out_and_idx_chunks = output.chunks_exact_mut(4).zip(indexes.chunks_exact(4));
+        for (out_x4, idx_x4) in out_and_idx_chunks {
+            out_x4[0] = self.get_val(idx_x4[0]);
+            out_x4[1] = self.get_val(idx_x4[1]);
+            out_x4[2] = self.get_val(idx_x4[2]);
+            out_x4[3] = self.get_val(idx_x4[3]);
+        }
+
+        let step_size = 4;
+        let cutoff = indexes.len() - indexes.len() % step_size;
+
+        for idx in cutoff..indexes.len() {
+            output[idx] = self.get_val(indexes[idx] as u32);
+        }
+    }
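The new default implementation above is a manual 4-way unroll. A minimal, self-contained sketch of the same pattern (the trait is reduced to two methods, and `VecColumn` is a hypothetical stand-in, not tantivy's type):

    // Batch lookup with a 4-way unrolled main loop and a scalar tail.
    trait ColumnValues {
        fn get_val(&self, idx: u32) -> u64;

        fn get_vals(&self, indexes: &[u32], output: &mut [u64]) {
            assert!(indexes.len() == output.len());
            // chunks_exact(4) hands the optimizer four independent lookups per
            // iteration; in-chunk bounds checks on the length-4 slices can be elided.
            for (out_x4, idx_x4) in output.chunks_exact_mut(4).zip(indexes.chunks_exact(4)) {
                out_x4[0] = self.get_val(idx_x4[0]);
                out_x4[1] = self.get_val(idx_x4[1]);
                out_x4[2] = self.get_val(idx_x4[2]);
                out_x4[3] = self.get_val(idx_x4[3]);
            }
            // Scalar tail for the last len % 4 indexes.
            let cutoff = indexes.len() - indexes.len() % 4;
            for i in cutoff..indexes.len() {
                output[i] = self.get_val(indexes[i]);
            }
        }
    }

    struct VecColumn(Vec<u64>);

    impl ColumnValues for VecColumn {
        fn get_val(&self, idx: u32) -> u64 {
            self.0[idx as usize]
        }
    }

    fn main() {
        let col = VecColumn((0..10u64).map(|v| v * 2).collect());
        let idxs = [1u32, 3, 5, 7, 9, 2];
        let mut out = vec![0u64; idxs.len()];
        col.get_vals(&idxs, &mut out);
        assert_eq!(out, vec![2, 6, 10, 14, 18, 4]);
    }

The per-element `get_val` calls remain, but the intent of the unroll is to let several independent lookups be in flight at once instead of forming one serial dependency chain per element.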
@@ -50,7 +50,7 @@ where
     Input: PartialOrd + Send + Debug + Sync + Clone,
     Output: PartialOrd + Send + Debug + Sync + Clone,
 {
-    #[inline]
+    #[inline(always)]
     fn get_val(&self, idx: u32) -> Output {
         let from_val = self.from_column.get_val(idx);
         self.monotonic_mapping.mapping(from_val)
@@ -99,14 +99,28 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
     let reader = TColumnCodec::load(OwnedBytes::new(buffer)).unwrap();
     assert_eq!(reader.num_vals(), vals.len() as u32);
+    let mut buffer = Vec::new();
     for (doc, orig_val) in vals.iter().copied().enumerate() {
         let val = reader.get_val(doc as u32);
         assert_eq!(
             val, orig_val,
             "val `{val}` does not match orig_val {orig_val:?}, in data set {name}, data `{vals:?}`",
         );
+
+        buffer.resize(1, 0);
+        reader.get_vals(&[doc as u32], &mut buffer);
+        let val = buffer[0];
+        assert_eq!(
+            val, orig_val,
+            "val `{val}` does not match orig_val {orig_val:?}, in data set {name}, data `{vals:?}`",
+        );
     }
+
+    let all_docs: Vec<u32> = (0..vals.len() as u32).collect();
+    buffer.resize(all_docs.len(), 0);
+    reader.get_vals(&all_docs, &mut buffer);
+    assert_eq!(vals, buffer);
+
     if !vals.is_empty() {
         let test_rand_idx = rand::thread_rng().gen_range(0..=vals.len() - 1);
         let expected_positions: Vec<u32> = vals
@@ -230,6 +230,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
         })
     }
 
+    #[inline]
     fn collect(
         &mut self,
         doc: crate::DocId,
@@ -238,6 +239,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
         self.collect_block(&[doc], agg_with_accessor)
     }
 
+    #[inline]
     fn collect_block(
         &mut self,
         docs: &[crate::DocId],
@@ -208,6 +208,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
         })
     }
 
+    #[inline]
     fn collect(
         &mut self,
         doc: crate::DocId,
@@ -216,6 +217,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
         self.collect_block(&[doc], agg_with_accessor)
     }
 
+    #[inline]
     fn collect_block(
         &mut self,
         docs: &[crate::DocId],
@@ -205,7 +205,8 @@ impl TermsAggregationInternal {
 #[derive(Clone, Debug, Default)]
 /// Container to store term_ids/or u64 values and their buckets.
 struct TermBuckets {
-    pub(crate) entries: FxHashMap<u64, TermBucketEntry>,
+    pub(crate) entries: FxHashMap<u64, u64>,
+    pub(crate) sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
 }
 
 #[derive(Clone, Default)]
@@ -249,10 +250,8 @@ impl TermBucketEntry {
 
 impl TermBuckets {
     fn force_flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) -> crate::Result<()> {
-        for entry in &mut self.entries.values_mut() {
-            if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
-                sub_aggregations.flush(agg_with_accessor)?;
-            }
+        for sub_aggregations in &mut self.sub_aggs.values_mut() {
+            sub_aggregations.as_mut().flush(agg_with_accessor)?;
         }
         Ok(())
     }
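The shape of this refactor, reduced to a standalone sketch (std's HashMap stands in for FxHashMap, and `SubAgg` is a hypothetical placeholder for a boxed sub-collector): counting now bumps a plain u64 in one map, while sub-aggregation state lives in a second map, keyed the same way, that is only touched when the request actually has sub-aggregations.

    use std::collections::HashMap;

    // Hypothetical stand-in for Box<dyn SegmentAggregationCollector>.
    #[derive(Default)]
    struct SubAgg {
        docs: Vec<u32>,
    }

    #[derive(Default)]
    struct TermBuckets {
        // Hot path: count per term id, no struct or Option in the way.
        entries: HashMap<u64, u64>,
        // Cold path: per-term sub-aggregation state.
        sub_aggs: HashMap<u64, SubAgg>,
    }

    impl TermBuckets {
        fn collect(&mut self, term_id: u64, doc: u32, has_sub_agg: bool) {
            *self.entries.entry(term_id).or_default() += 1;
            if has_sub_agg {
                self.sub_aggs.entry(term_id).or_default().docs.push(doc);
            }
        }
    }

    fn main() {
        let mut buckets = TermBuckets::default();
        for doc in 0..8u32 {
            buckets.collect(u64::from(doc % 3), doc, false); // count-only request
        }
        assert_eq!(buckets.entries[&0], 3);
        assert_eq!(buckets.entries[&1], 3);
        assert_eq!(buckets.entries[&2], 2);
    }

Splitting the maps keeps the count-only case, by far the common one, free of the per-entry Option check and the larger value type that `TermBucketEntry` carried.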
@@ -268,6 +267,7 @@ pub struct SegmentTermCollector {
     blueprint: Option<Box<dyn SegmentAggregationCollector>>,
     field_type: ColumnType,
     accessor_idx: usize,
+    val_cache: Vec<u64>,
 }
 
 pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
@@ -292,6 +292,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
         })
     }
 
+    #[inline]
     fn collect(
         &mut self,
         doc: crate::DocId,
@@ -300,6 +301,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
         self.collect_block(&[doc], agg_with_accessor)
     }
 
+    #[inline]
     fn collect_block(
         &mut self,
         docs: &[crate::DocId],
@@ -310,28 +312,35 @@ impl SegmentAggregationCollector for SegmentTermCollector {
             &agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation;
 
         if accessor.get_cardinality() == Cardinality::Full {
-            for doc in docs {
-                let term_id = accessor.values.get_val(*doc);
-                let entry = self
-                    .term_buckets
-                    .entries
-                    .entry(term_id)
-                    .or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
-                entry.doc_count += 1;
-                if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
-                    sub_aggregations.collect(*doc, sub_aggregation_accessor)?;
+            self.val_cache.resize(docs.len(), 0);
+            accessor.values.get_vals(docs, &mut self.val_cache);
+            for (doc, term_id) in docs.iter().zip(self.val_cache.iter().cloned()) {
+                let entry = self.term_buckets.entries.entry(term_id).or_default();
+                *entry += 1;
+            }
+            // has subagg
+            if let Some(blueprint) = self.blueprint.as_ref() {
+                for (doc, term_id) in docs.iter().zip(self.val_cache.iter().cloned()) {
+                    let sub_aggregations = self
+                        .term_buckets
+                        .sub_aggs
+                        .entry(term_id)
+                        .or_insert_with(|| blueprint.clone());
+                    sub_aggregations.collect(*doc, sub_aggregation_accessor)?;
                 }
             }
         } else {
             for doc in docs {
                 for term_id in accessor.values_for_doc(*doc) {
-                    let entry = self
-                        .term_buckets
-                        .entries
-                        .entry(term_id)
-                        .or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
-                    entry.doc_count += 1;
-                    if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
+                    let entry = self.term_buckets.entries.entry(term_id).or_default();
+                    *entry += 1;
+                    // TODO: check if separate loop is faster (may depend on the codec)
+                    if let Some(blueprint) = self.blueprint.as_ref() {
+                        let sub_aggregations = self
+                            .term_buckets
+                            .sub_aggs
+                            .entry(term_id)
+                            .or_insert_with(|| blueprint.clone());
                         sub_aggregations.collect(*doc, sub_aggregation_accessor)?;
                     }
                 }
@@ -386,15 +395,16 @@ impl SegmentTermCollector {
             blueprint,
             field_type,
             accessor_idx,
+            val_cache: Default::default(),
         })
     }
 
+    #[inline]
     pub(crate) fn into_intermediate_bucket_result(
-        self,
+        mut self,
         agg_with_accessor: &BucketAggregationWithAccessor,
     ) -> crate::Result<IntermediateBucketResult> {
-        let mut entries: Vec<(u64, TermBucketEntry)> =
-            self.term_buckets.entries.into_iter().collect();
+        let mut entries: Vec<(u64, u64)> = self.term_buckets.entries.into_iter().collect();
 
         let order_by_sub_aggregation =
             matches!(self.req.order.target, OrderTarget::SubAggregation(_));
@@ -417,9 +427,9 @@ impl SegmentTermCollector {
             }
             OrderTarget::Count => {
                 if self.req.order.order == Order::Desc {
-                    entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.doc_count()));
+                    entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.1));
                 } else {
-                    entries.sort_unstable_by_key(|bucket| bucket.doc_count());
+                    entries.sort_unstable_by_key(|bucket| bucket.1);
                 }
             }
         }
@@ -432,6 +442,33 @@ impl SegmentTermCollector {
 
         let mut dict: FxHashMap<Key, IntermediateTermBucketEntry> = Default::default();
         dict.reserve(entries.len());
+
+        let mut into_intermediate_bucket_entry =
+            |id, doc_count| -> crate::Result<IntermediateTermBucketEntry> {
+                let intermediate_entry = if let Some(blueprint) = self.blueprint.as_ref() {
+                    IntermediateTermBucketEntry {
+                        doc_count,
+                        sub_aggregation: self
+                            .term_buckets
+                            .sub_aggs
+                            .remove(&id)
+                            .expect(&format!(
+                                "Internal Error: could not find subaggregation for id {}",
+                                id
+                            ))
+                            .into_intermediate_aggregations_result(
+                                &agg_with_accessor.sub_aggregation,
+                            )?,
+                    }
+                } else {
+                    IntermediateTermBucketEntry {
+                        doc_count,
+                        sub_aggregation: Default::default(),
+                    }
+                };
+                Ok(intermediate_entry)
+            };
+
         if self.field_type == ColumnType::Str {
             let term_dict = agg_with_accessor
                 .str_dict_column
@@ -439,17 +476,17 @@ impl SegmentTermCollector {
             .expect("internal error: term dictionary not found for term aggregation");
 
         let mut buffer = String::new();
-        for (term_id, entry) in entries {
+        for (term_id, doc_count) in entries {
             if !term_dict.ord_to_str(term_id, &mut buffer)? {
                 return Err(TantivyError::InternalError(format!(
                     "Couldn't find term_id {} in dict",
                     term_id
                 )));
             }
-            dict.insert(
-                Key::Str(buffer.to_string()),
-                entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
-            );
+
+            let intermediate_entry = into_intermediate_bucket_entry(term_id, doc_count)?;
+
+            dict.insert(Key::Str(buffer.to_string()), intermediate_entry);
         }
         if self.req.min_doc_count == 0 {
             // TODO: Handle rev streaming for descending sorting by keys
@@ -468,12 +505,10 @@ impl SegmentTermCollector {
                 }
             }
         } else {
-            for (val, entry) in entries {
+            for (val, doc_count) in entries {
+                let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
                 let val = f64_from_fastfield_u64(val, &self.field_type);
-                dict.insert(
-                    Key::F64(val),
-                    entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
-                );
+                dict.insert(Key::F64(val), intermediate_entry);
             }
         };
 
@@ -495,6 +530,11 @@ impl GetDocCount for (u32, TermBucketEntry) {
         self.1.doc_count
     }
 }
+impl GetDocCount for (u64, u64) {
+    fn doc_count(&self) -> u64 {
+        self.1
+    }
+}
 impl GetDocCount for (u64, TermBucketEntry) {
     fn doc_count(&self) -> u64 {
         self.1.doc_count
@@ -34,6 +34,7 @@ impl BufAggregationCollector {
 }
 
 impl SegmentAggregationCollector for BufAggregationCollector {
+    #[inline]
     fn into_intermediate_aggregations_result(
         self: Box<Self>,
         agg_with_accessor: &AggregationsWithAccessor,
@@ -41,6 +42,7 @@ impl SegmentAggregationCollector for BufAggregationCollector {
         Box::new(self.collector).into_intermediate_aggregations_result(agg_with_accessor)
     }
 
+    #[inline]
     fn collect(
         &mut self,
         doc: crate::DocId,
@@ -56,6 +58,7 @@ impl SegmentAggregationCollector for BufAggregationCollector {
         Ok(())
     }
 
+    #[inline]
     fn collect_block(
         &mut self,
         docs: &[crate::DocId],
@@ -67,6 +70,7 @@ impl SegmentAggregationCollector for BufAggregationCollector {
         Ok(())
     }
 
+    #[inline]
     fn flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) -> crate::Result<()> {
         self.collector
             .collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;
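For context, `BufAggregationCollector` is the decorator that turns single-doc `collect` calls into block calls: doc ids are staged in a fixed-size buffer and handed to the wrapped collector's `collect_block` once full, and once more on `flush`. A hedged sketch of that staging pattern, with illustrative names and a made-up block size:

    const BLOCK_SIZE: usize = 64; // illustrative, not necessarily tantivy's value

    struct BufCollector {
        staged_docs: [u32; BLOCK_SIZE],
        num_staged_docs: usize,
        // Stand-in for the wrapped SegmentAggregationCollector.
        collected_blocks: Vec<Vec<u32>>,
    }

    impl BufCollector {
        fn collect(&mut self, doc: u32) {
            self.staged_docs[self.num_staged_docs] = doc;
            self.num_staged_docs += 1;
            if self.num_staged_docs == BLOCK_SIZE {
                self.flush();
            }
        }

        // Forward whatever is staged as one block, as in the hunk above.
        fn flush(&mut self) {
            if self.num_staged_docs > 0 {
                self.collected_blocks
                    .push(self.staged_docs[..self.num_staged_docs].to_vec());
                self.num_staged_docs = 0;
            }
        }
    }

    fn main() {
        let mut buf = BufCollector {
            staged_docs: [0; BLOCK_SIZE],
            num_staged_docs: 0,
            collected_blocks: Vec::new(),
        };
        for doc in 0..150u32 {
            buf.collect(doc);
        }
        buf.flush(); // drain the tail
        assert_eq!(buf.collected_blocks.iter().map(Vec::len).sum::<usize>(), 150);
    }

This batching is what lets the `get_vals`-based fast paths above see usefully large doc slices.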
@@ -156,6 +156,7 @@ pub(crate) struct SegmentStatsCollector {
     pub(crate) collecting_for: SegmentStatsType,
     pub(crate) stats: IntermediateStats,
     pub(crate) accessor_idx: usize,
+    val_cache: Vec<u64>,
 }
 
 impl SegmentStatsCollector {
@@ -169,14 +170,16 @@ impl SegmentStatsCollector {
             collecting_for,
             stats: IntermediateStats::default(),
             accessor_idx,
+            val_cache: Default::default(),
         }
     }
     #[inline]
     pub(crate) fn collect_block_with_field(&mut self, docs: &[DocId], field: &Column<u64>) {
         if field.get_cardinality() == Cardinality::Full {
-            for doc in docs {
-                let val = field.values.get_val(*doc);
-                let val1 = f64_from_fastfield_u64(val, &self.field_type);
+            self.val_cache.resize(docs.len(), 0);
+            field.values.get_vals(docs, &mut self.val_cache);
+            for val in self.val_cache.iter() {
+                let val1 = f64_from_fastfield_u64(*val, &self.field_type);
                 self.stats.collect(val1);
             }
         } else {
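The consumer side of the batched read, as a self-contained sketch (types reduced to stand-ins; the real code converts through `f64_from_fastfield_u64` and handles the non-Full cardinality branch separately):

    #[derive(Default)]
    struct Stats {
        count: u64,
        sum: f64,
    }

    impl Stats {
        fn collect(&mut self, val: f64) {
            self.count += 1;
            self.sum += val;
        }
    }

    // Stand-in for a Column with full cardinality.
    struct Column(Vec<u64>);

    impl Column {
        fn get_vals(&self, docs: &[u32], out: &mut [u64]) {
            for (o, &doc) in out.iter_mut().zip(docs) {
                *o = self.0[doc as usize];
            }
        }
    }

    fn collect_block(col: &Column, docs: &[u32], val_cache: &mut Vec<u64>, stats: &mut Stats) {
        // Resize the reusable scratch buffer to the block, fetch once, then fold.
        val_cache.resize(docs.len(), 0);
        col.get_vals(docs, val_cache);
        for &val in val_cache.iter() {
            stats.collect(val as f64);
        }
    }

    fn main() {
        let col = Column(vec![10, 20, 30, 40]);
        let mut val_cache = Vec::new();
        let mut stats = Stats::default();
        collect_block(&col, &[0, 2, 3], &mut val_cache, &mut stats);
        assert_eq!(stats.count, 3);
        assert_eq!(stats.sum, 80.0);
    }

Keeping `val_cache` on the collector rather than allocating per block makes the resize essentially free after the first call.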
@@ -191,6 +194,7 @@ impl SegmentStatsCollector {
     }
 }
 
 impl SegmentAggregationCollector for SegmentStatsCollector {
+    #[inline]
     fn into_intermediate_aggregations_result(
         self: Box<Self>,
         agg_with_accessor: &AggregationsWithAccessor,
@@ -227,6 +231,7 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
         })
     }
 
+    #[inline]
     fn collect(
         &mut self,
         doc: crate::DocId,