From e510f699c8e261a10f30cded0330aa25a82025c1 Mon Sep 17 00:00:00 2001
From: PSeitz
Date: Mon, 27 Feb 2023 15:04:41 +0800
Subject: [PATCH] feat: add support for u64,i64,f64 fields in term aggregation
 (#1883)

* feat: add support for u64,i64,f64 fields in term aggregation

* hash enum values

* fix build

* Apply suggestions from code review

Co-authored-by: Paul Masurel

---------

Co-authored-by: Paul Masurel
---
 src/aggregation/bucket/term_agg.rs         | 204 ++++++++++++++++-----
 src/aggregation/intermediate_agg_result.rs |  25 ++-
 src/aggregation/mod.rs                     |  24 ++-
 src/aggregation/segment_agg_result.rs      |   1 +
 4 files changed, 205 insertions(+), 49 deletions(-)

diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs
index 6f884c1ac..f7f991a4c 100644
--- a/src/aggregation/bucket/term_agg.rs
+++ b/src/aggregation/bucket/term_agg.rs
@@ -1,6 +1,6 @@
 use std::fmt::Debug;
 
-use columnar::Cardinality;
+use columnar::{Cardinality, ColumnType};
 use rustc_hash::FxHashMap;
 use serde::{Deserialize, Serialize};
 
@@ -15,7 +15,7 @@ use crate::aggregation::intermediate_agg_result::{
 use crate::aggregation::segment_agg_result::{
     build_segment_agg_collector, SegmentAggregationCollector,
 };
-use crate::aggregation::VecWithNames;
+use crate::aggregation::{f64_from_fastfield_u64, Key, VecWithNames};
 use crate::error::DataCorruption;
 use crate::TantivyError;
@@ -25,6 +25,10 @@
 /// If the text is untokenized and single value, that means one term per document and therefore it
 /// is in fact doc count.
 ///
+/// ## Prerequisite
+/// Term aggregations work only on [fast fields](`crate::fastfield`) of type `u64`, `f64`, `i64`
+/// and text.
+///
 /// ### Terminology
 /// Shard parameters are supposed to be equivalent to elasticsearch shard parameter.
 /// Since they are
@@ -199,9 +203,9 @@ impl TermsAggregationInternal {
 }
 
 #[derive(Clone, Debug, Default)]
-/// Container to store term_ids and their buckets.
+/// Container to store term_ids or u64 values and their buckets.
 struct TermBuckets {
-    pub(crate) entries: FxHashMap<u32, TermBucketEntry>,
+    pub(crate) entries: FxHashMap<u64, TermBucketEntry>,
 }
 
 #[derive(Clone, Default)]
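Why a plain `u64` can key both kinds of buckets: for text columns the value is a term ordinal, while for numeric columns the fast field stores `i64`/`f64` values under an order-preserving `u64` encoding, which `f64_from_fastfield_u64` (imported above) later undoes. A minimal sketch of such an encoding, assuming the usual sign-bit tricks rather than tantivy's exact `common`-crate helpers:

```rust
/// Map an i64 to a u64 so that unsigned order matches signed order:
/// flipping the sign bit moves negative values below positive ones.
fn i64_to_u64(val: i64) -> u64 {
    (val as u64) ^ (1 << 63)
}

fn u64_to_i64(val: u64) -> i64 {
    (val ^ (1 << 63)) as i64
}

/// Map an f64 to a u64 so that unsigned order matches numeric order.
fn f64_to_u64(val: f64) -> u64 {
    let bits = val.to_bits();
    if bits >> 63 == 0 {
        bits ^ (1 << 63) // positive: set the sign bit
    } else {
        !bits // negative: invert all bits to reverse the order
    }
}

fn u64_to_f64(val: u64) -> f64 {
    f64::from_bits(if val >> 63 == 1 { val ^ (1 << 63) } else { !val })
}

fn main() {
    assert!(i64_to_u64(-2) < i64_to_u64(3));
    assert!(f64_to_u64(-1.5) < f64_to_u64(0.25));
    assert_eq!(u64_to_i64(i64_to_u64(-42)), -42);
    assert_eq!(u64_to_f64(f64_to_u64(2.25)), 2.25);
}
```

With this in place, a single `FxHashMap<u64, TermBucketEntry>` serves text, `u64`, `i64`, and `f64` columns alike during collection.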
@@ -262,6 +266,7 @@
 pub struct SegmentTermCollector {
     term_buckets: TermBuckets,
     req: TermsAggregationInternal,
     blueprint: Option<Box<dyn SegmentAggregationCollector>>,
+    field_type: ColumnType,
     accessor_idx: usize,
 }
@@ -310,7 +315,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
         let entry = self
             .term_buckets
             .entries
-            .entry(term_id as u32)
+            .entry(term_id)
             .or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
         entry.doc_count += 1;
         if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
@@ -323,7 +328,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
         let entry = self
             .term_buckets
             .entries
-            .entry(term_id as u32)
+            .entry(term_id)
             .or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
         entry.doc_count += 1;
         if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
@@ -348,6 +353,7 @@ impl SegmentTermCollector {
     pub(crate) fn from_req_and_validate(
         req: &TermsAggregation,
         sub_aggregations: &AggregationsWithAccessor,
+        field_type: ColumnType,
         accessor_idx: usize,
     ) -> crate::Result<Self> {
         let term_buckets = TermBuckets::default();
@@ -378,6 +384,7 @@
             req: TermsAggregationInternal::from_req(req),
             term_buckets,
             blueprint,
+            field_type,
             accessor_idx,
         })
     }
@@ -386,7 +393,7 @@ impl SegmentTermCollector {
         self,
         agg_with_accessor: &BucketAggregationWithAccessor,
     ) -> crate::Result<IntermediateBucketResult> {
-        let mut entries: Vec<(u32, TermBucketEntry)> =
+        let mut entries: Vec<(u64, TermBucketEntry)> =
             self.term_buckets.entries.into_iter().collect();
 
         let order_by_sub_aggregation =
@@ -423,41 +430,52 @@
             cut_off_buckets(&mut entries, self.req.segment_size as usize)
         };
 
-        let mut dict: FxHashMap<String, IntermediateTermBucketEntry> = Default::default();
+        let mut dict: FxHashMap<Key, IntermediateTermBucketEntry> = Default::default();
+        dict.reserve(entries.len());
+        if self.field_type == ColumnType::Str {
+            let term_dict = agg_with_accessor
+                .str_dict_column
+                .as_ref()
+                .expect("internal error: term dictionary not found for term aggregation");
-        let str_column = agg_with_accessor
-            .str_dict_column
-            .as_ref()
-            .expect("Missing str column"); //< TODO Fixme
-
-        let mut buffer = String::new();
-        for (term_id, entry) in entries {
-            if !str_column.ord_to_str(term_id as u64, &mut buffer)? {
-                return Err(TantivyError::InternalError(format!(
-                    "Couldn't find term_id {} in dict",
-                    term_id
-                )));
-            }
-            dict.insert(
-                buffer.to_string(),
-                entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
-            );
-        }
-        if self.req.min_doc_count == 0 {
-            // TODO: Handle rev streaming for descending sorting by keys
-            let mut stream = str_column.dictionary().stream()?;
-            while let Some((key, _ord)) = stream.next() {
-                if dict.len() >= self.req.segment_size as usize {
-                    break;
+            let mut buffer = String::new();
+            for (term_id, entry) in entries {
+                if !term_dict.ord_to_str(term_id, &mut buffer)? {
+                    return Err(TantivyError::InternalError(format!(
+                        "Couldn't find term_id {} in dict",
+                        term_id
+                    )));
                 }
+                dict.insert(
+                    Key::Str(buffer.to_string()),
+                    entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
+                );
+            }
+            if self.req.min_doc_count == 0 {
+                // TODO: Handle rev streaming for descending sorting by keys
+                let mut stream = term_dict.dictionary().stream()?;
+                while let Some((key, _ord)) = stream.next() {
+                    if dict.len() >= self.req.segment_size as usize {
+                        break;
+                    }
 
-                let key = std::str::from_utf8(key)
-                    .map_err(|utf8_err| DataCorruption::comment_only(utf8_err.to_string()))?;
-                if !dict.contains_key(key) {
-                    dict.insert(key.to_owned(), Default::default());
+                    let key = Key::Str(
+                        std::str::from_utf8(key)
+                            .map_err(|utf8_err| DataCorruption::comment_only(utf8_err.to_string()))?
+                            .to_string(),
+                    );
+                    dict.entry(key).or_default();
                 }
             }
-        }
+        } else {
+            for (val, entry) in entries {
+                let val = f64_from_fastfield_u64(val, &self.field_type);
+                dict.insert(
+                    Key::F64(val),
+                    entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
+                );
+            }
+        };
 
         Ok(IntermediateBucketResult::Terms(
             IntermediateTermBucketResult {
@@ -477,6 +495,11 @@ impl GetDocCount for (u32, TermBucketEntry) {
         self.1.doc_count
     }
 }
+impl GetDocCount for (u64, TermBucketEntry) {
+    fn doc_count(&self) -> u64 {
+        self.1.doc_count
+    }
+}
 impl GetDocCount for (String, IntermediateTermBucketEntry) {
     fn doc_count(&self) -> u64 {
         self.1.doc_count
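The string and numeric paths in `into_intermediate_bucket_result` above differ only in how the collected `u64` becomes a `Key`: term ordinals go through the term dictionary, numeric values through `f64_from_fastfield_u64`. The same fan-out restated on simplified types (`lookup_term`, `decode_f64`, and this `ColumnType` are illustrative stand-ins, not tantivy's API):

```rust
#[derive(Debug)]
enum Key {
    Str(String),
    F64(f64),
}

#[derive(Clone, Copy, PartialEq)]
enum ColumnType {
    Str,
    Numeric, // u64/i64/f64 all take this branch
}

/// Stand-in for the term dictionary's `ord_to_str` lookup.
fn lookup_term(term_ord: u64) -> String {
    format!("term_{term_ord}")
}

/// Stand-in for `f64_from_fastfield_u64` (see the encoding sketch earlier).
fn decode_f64(val: u64) -> f64 {
    f64::from_bits(if val >> 63 == 1 { val ^ (1 << 63) } else { !val })
}

/// Turn collected (u64 key, doc_count) pairs into keyed buckets.
fn into_buckets(entries: Vec<(u64, u64)>, column_type: ColumnType) -> Vec<(Key, u64)> {
    entries
        .into_iter()
        .map(|(key, doc_count)| {
            let key = match column_type {
                // For text columns the u64 is a term ordinal.
                ColumnType::Str => Key::Str(lookup_term(key)),
                // For numeric columns it is the encoded fast-field value.
                ColumnType::Numeric => Key::F64(decode_f64(key)),
            };
            (key, doc_count)
        })
        .collect()
}

fn main() {
    let buckets = into_buckets(vec![(0, 3), (1, 2)], ColumnType::Str);
    println!("{buckets:?}"); // [(Str("term_0"), 3), (Str("term_1"), 2)]
}
```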
@@ -620,7 +643,8 @@ mod tests {
     fn terms_aggregation_test_order_count_merge_segment(merge_segments: bool) -> crate::Result<()> {
         let segment_and_terms = vec![
             vec![(5.0, "terma".to_string())],
-            vec![(4.0, "termb".to_string())],
+            vec![(2.0, "termb".to_string())],
+            vec![(2.0, "terma".to_string())],
             vec![(1.0, "termc".to_string())],
             vec![(1.0, "termc".to_string())],
             vec![(1.0, "termc".to_string())],
@@ -661,7 +685,7 @@
                 }),
                 ..Default::default()
             }),
-            sub_aggregation: sub_agg,
+            sub_aggregation: sub_agg.clone(),
         }),
     )]
     .into_iter()
     .collect();
@@ -670,18 +694,114 @@
         let res = exec_request(agg_req, &index)?;
         assert_eq!(res["my_texts"]["buckets"][0]["key"], "termb");
         assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 2);
-        assert_eq!(res["my_texts"]["buckets"][0]["avg_score"]["value"], 6.0);
+        assert_eq!(res["my_texts"]["buckets"][0]["avg_score"]["value"], 5.0);
         assert_eq!(res["my_texts"]["buckets"][1]["key"], "termc");
         assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 3);
         assert_eq!(res["my_texts"]["buckets"][1]["avg_score"]["value"], 1.0);
         assert_eq!(res["my_texts"]["buckets"][2]["key"], "terma");
-        assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 5);
-        assert_eq!(res["my_texts"]["buckets"][2]["avg_score"]["value"], 5.0);
+        assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 6);
+        assert_eq!(res["my_texts"]["buckets"][2]["avg_score"]["value"], 4.5);
 
         assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
 
+        // Aggregation on non-string (numeric) fields.
+        let agg_req: Aggregations = vec![
+            (
+                "my_scores1".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Terms(TermsAggregation {
+                        field: "score".to_string(),
+                        order: Some(CustomOrder {
+                            order: Order::Asc,
+                            target: OrderTarget::Count,
+                        }),
+                        ..Default::default()
+                    }),
+                    sub_aggregation: sub_agg.clone(),
+                }),
+            ),
+            (
+                "my_scores2".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Terms(TermsAggregation {
+                        field: "score_f64".to_string(),
+                        order: Some(CustomOrder {
+                            order: Order::Asc,
+                            target: OrderTarget::Count,
+                        }),
+                        ..Default::default()
+                    }),
+                    sub_aggregation: sub_agg.clone(),
+                }),
+            ),
+            (
+                "my_scores3".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Terms(TermsAggregation {
+                        field: "score_i64".to_string(),
+                        order: Some(CustomOrder {
+                            order: Order::Asc,
+                            target: OrderTarget::Count,
+                        }),
+                        ..Default::default()
+                    }),
+                    sub_aggregation: sub_agg,
+                }),
+            ),
+        ]
+        .into_iter()
+        .collect();
+
+        let res = exec_request(agg_req, &index)?;
+        assert_eq!(res["my_scores1"]["buckets"][0]["key"], 8.0);
+        assert_eq!(res["my_scores1"]["buckets"][0]["doc_count"], 1);
+        assert_eq!(res["my_scores1"]["buckets"][0]["avg_score"]["value"], 8.0);
+
+        assert_eq!(res["my_scores1"]["buckets"][1]["key"], 2.0);
+        assert_eq!(res["my_scores1"]["buckets"][1]["doc_count"], 2);
+        assert_eq!(res["my_scores1"]["buckets"][1]["avg_score"]["value"], 2.0);
+
+        assert_eq!(res["my_scores1"]["buckets"][2]["key"], 1.0);
+        assert_eq!(res["my_scores1"]["buckets"][2]["doc_count"], 3);
+        assert_eq!(res["my_scores1"]["buckets"][2]["avg_score"]["value"], 1.0);
+
+        assert_eq!(res["my_scores1"]["buckets"][3]["key"], 5.0);
+        assert_eq!(res["my_scores1"]["buckets"][3]["doc_count"], 5);
+        assert_eq!(res["my_scores1"]["buckets"][3]["avg_score"]["value"], 5.0);
+
+        assert_eq!(res["my_scores1"]["sum_other_doc_count"], 0);
+
+        assert_eq!(res["my_scores2"]["buckets"][0]["key"], 8.0);
+        assert_eq!(res["my_scores2"]["buckets"][0]["doc_count"], 1);
+        assert_eq!(res["my_scores2"]["buckets"][0]["avg_score"]["value"], 8.0);
+
+        assert_eq!(res["my_scores2"]["buckets"][1]["key"], 2.0);
+        assert_eq!(res["my_scores2"]["buckets"][1]["doc_count"], 2);
+        assert_eq!(res["my_scores2"]["buckets"][1]["avg_score"]["value"], 2.0);
+
+        assert_eq!(res["my_scores2"]["buckets"][2]["key"], 1.0);
+        assert_eq!(res["my_scores2"]["buckets"][2]["doc_count"], 3);
+        assert_eq!(res["my_scores2"]["buckets"][2]["avg_score"]["value"], 1.0);
+
+        assert_eq!(res["my_scores2"]["sum_other_doc_count"], 0);
+
+        assert_eq!(res["my_scores3"]["buckets"][0]["key"], 8.0);
+        assert_eq!(res["my_scores3"]["buckets"][0]["doc_count"], 1);
+        assert_eq!(res["my_scores3"]["buckets"][0]["avg_score"]["value"], 8.0);
+
+        assert_eq!(res["my_scores3"]["buckets"][1]["key"], 2.0);
+        assert_eq!(res["my_scores3"]["buckets"][1]["doc_count"], 2);
+        assert_eq!(res["my_scores3"]["buckets"][1]["avg_score"]["value"], 2.0);
+
+        assert_eq!(res["my_scores3"]["buckets"][2]["key"], 1.0);
+        assert_eq!(res["my_scores3"]["buckets"][2]["doc_count"], 3);
+        assert_eq!(res["my_scores3"]["buckets"][2]["avg_score"]["value"], 1.0);
+
+        assert_eq!(res["my_scores3"]["sum_other_doc_count"], 0);
+
         Ok(())
     }
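Seen from the outside, the new tests boil down to this: a terms aggregation may now name a numeric fast field, and bucket keys come back as JSON numbers. A sketch of the request/response shape, mirroring the `my_scores1` assertions above (assumes `serde_json`; the exact request format is tantivy's Elasticsearch-compatible aggregation JSON, so treat the snippet as illustrative):

```rust
use serde_json::json;

fn main() {
    // Terms aggregation over the numeric fast field "score",
    // ordered by ascending document count.
    let request = json!({
        "my_scores1": {
            "terms": {
                "field": "score",
                "order": { "_count": "asc" }
            }
        }
    });

    // Shape of the result the assertions above check: numeric keys,
    // buckets sorted by rising doc_count.
    let response = json!({
        "my_scores1": {
            "buckets": [
                { "key": 8.0, "doc_count": 1, "avg_score": { "value": 8.0 } },
                { "key": 2.0, "doc_count": 2, "avg_score": { "value": 2.0 } },
                { "key": 1.0, "doc_count": 3, "avg_score": { "value": 1.0 } },
                { "key": 5.0, "doc_count": 5, "avg_score": { "value": 5.0 } }
            ],
            "sum_other_doc_count": 0
        }
    });

    println!("{request:#}\n{response:#}");
}
```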
diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs
index 5f26d627b..b07b02f02 100644
--- a/src/aggregation/intermediate_agg_result.rs
+++ b/src/aggregation/intermediate_agg_result.rs
@@ -373,7 +373,7 @@ impl IntermediateBucketResult {
             IntermediateBucketResult::Terms(term_res_left),
             IntermediateBucketResult::Terms(term_res_right),
         ) => {
-            merge_maps(&mut term_res_left.entries, term_res_right.entries);
+            merge_key_maps(&mut term_res_left.entries, term_res_right.entries);
             term_res_left.sum_other_doc_count += term_res_right.sum_other_doc_count;
             term_res_left.doc_count_error_upper_bound +=
                 term_res_right.doc_count_error_upper_bound;
@@ -383,7 +383,7 @@
         (
             IntermediateBucketResult::Range(range_res_left),
             IntermediateBucketResult::Range(range_res_right),
         ) => {
-            merge_maps(&mut range_res_left.buckets, range_res_right.buckets);
+            merge_serialized_key_maps(&mut range_res_left.buckets, range_res_right.buckets);
         }
         (
             IntermediateBucketResult::Histogram {
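Both merge helpers follow the same per-key pattern, which the `merge_key_maps` added further down repeats for `Key`-keyed maps: entries present on both sides are combined, everything only the right side has is moved over. The same logic on plain std types, with a stand-in `MergeFruits` trait mirroring the one in this file:

```rust
use std::collections::HashMap;
use std::hash::Hash;

trait MergeFruits {
    fn merge_fruits(&mut self, other: Self);
}

#[derive(Clone, Debug, PartialEq)]
struct CountEntry {
    doc_count: u64,
}

impl MergeFruits for CountEntry {
    fn merge_fruits(&mut self, other: Self) {
        self.doc_count += other.doc_count;
    }
}

fn merge_maps<K: Eq + Hash, V: MergeFruits>(
    left: &mut HashMap<K, V>,
    mut right: HashMap<K, V>,
) {
    // Merge entries that exist on both sides.
    for (key, entry_left) in left.iter_mut() {
        if let Some(entry_right) = right.remove(key) {
            entry_left.merge_fruits(entry_right);
        }
    }
    // Adopt entries that only the right side has.
    for (key, entry) in right {
        left.entry(key).or_insert(entry);
    }
}

fn main() {
    let mut left = HashMap::from([("a", CountEntry { doc_count: 2 })]);
    let right = HashMap::from([
        ("a", CountEntry { doc_count: 3 }),
        ("b", CountEntry { doc_count: 1 }),
    ]);
    merge_maps(&mut left, right);
    assert_eq!(left["a"], CountEntry { doc_count: 5 });
    assert_eq!(left["b"], CountEntry { doc_count: 1 });
}
```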
@@ -435,7 +435,7 @@
 #[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
 /// Term aggregation including error counts
 pub struct IntermediateTermBucketResult {
-    pub(crate) entries: FxHashMap<String, IntermediateTermBucketEntry>,
+    pub(crate) entries: FxHashMap<Key, IntermediateTermBucketEntry>,
     pub(crate) sum_other_doc_count: u64,
     pub(crate) doc_count_error_upper_bound: u64,
 }
@@ -454,7 +454,7 @@ impl IntermediateTermBucketResult {
             .map(|(key, entry)| {
                 Ok(BucketEntry {
                     key_as_string: None,
-                    key: Key::Str(key),
+                    key,
                     doc_count: entry.doc_count,
                     sub_aggregation: entry
                         .sub_aggregation
@@ -532,7 +532,7 @@ trait MergeFruits {
     fn merge_fruits(&mut self, other: Self);
 }
 
-fn merge_maps<V: MergeFruits + Clone>(
+fn merge_serialized_key_maps<V: MergeFruits + Clone>(
     entries_left: &mut FxHashMap<SerializedKey, V>,
     mut entries_right: FxHashMap<SerializedKey, V>,
 ) {
@@ -547,6 +547,21 @@
     }
 }
 
+fn merge_key_maps<V: MergeFruits + Clone>(
+    entries_left: &mut FxHashMap<Key, V>,
+    mut entries_right: FxHashMap<Key, V>,
+) {
+    for (name, entry_left) in entries_left.iter_mut() {
+        if let Some(entry_right) = entries_right.remove(name) {
+            entry_left.merge_fruits(entry_right);
+        }
+    }
+
+    for (key, res) in entries_right.into_iter() {
+        entries_left.entry(key).or_insert(res);
+    }
+}
+
 /// This is the histogram entry for a bucket, which contains a key, count, and optionally
 /// sub_aggregations.
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs
index 24a02f19c..f3627499f 100644
--- a/src/aggregation/mod.rs
+++ b/src/aggregation/mod.rs
@@ -10,7 +10,7 @@
 //! There are two categories: [Metrics](metric) and [Buckets](bucket).
 //!
 //! ## Prerequisite
-//! Currently aggregations work only on [fast fields](`crate::fastfield`). Single value fast fields
+//! Currently aggregations work only on [fast fields](`crate::fastfield`): fast fields
 //! of type `u64`, `f64`, `i64`, `date` and fast fields on text fields.
 //!
 //! ## Usage
@@ -262,7 +262,7 @@ impl VecWithNames {
 /// The serialized key is used in a `HashMap`.
 pub type SerializedKey = String;
 
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, PartialOrd)]
+#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd)]
 /// The key to identify a bucket.
 #[serde(untagged)]
 pub enum Key {
@@ -271,6 +271,26 @@ pub enum Key {
     /// `String` key
     Str(String),
     /// `f64` key
     F64(f64),
 }
+impl Eq for Key {}
+impl std::hash::Hash for Key {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        core::mem::discriminant(self).hash(state);
+        match self {
+            Key::Str(text) => text.hash(state),
+            Key::F64(val) => val.to_bits().hash(state),
+        }
+    }
+}
+
+impl PartialEq for Key {
+    fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (Self::Str(l), Self::Str(r)) => l == r,
+            (Self::F64(l), Self::F64(r)) => l == r,
+            _ => false,
+        }
+    }
+}
 
 impl Display for Key {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs
index d2fab0a33..d88c6f8a7 100644
--- a/src/aggregation/segment_agg_result.rs
+++ b/src/aggregation/segment_agg_result.rs
@@ -150,6 +150,7 @@ pub(crate) fn build_bucket_segment_agg_collector(
             SegmentTermCollector::from_req_and_validate(
                 terms_req,
                 &req.sub_aggregation,
+                req.field_type,
                 accessor_idx,
             )?,
         )),
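Because `f64` is neither `Eq` nor `Hash`, `Key` gets hand-written impls in mod.rs above: hash the enum discriminant, then hash floats by bit pattern. A standalone restatement of that scheme on a local type; note that bitwise hashing tells `0.0` and `-0.0` apart, so those two values would land in separate hash buckets even though `==` calls them equal.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

#[derive(Debug)]
enum Key {
    Str(String),
    F64(f64),
}

impl Hash for Key {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Hash the variant first so Str and F64 keys cannot collide structurally.
        core::mem::discriminant(self).hash(state);
        match self {
            Key::Str(text) => text.hash(state),
            // f64 is not Hash; its raw bit pattern is.
            Key::F64(val) => val.to_bits().hash(state),
        }
    }
}

fn hash_of(key: &Key) -> u64 {
    let mut hasher = DefaultHasher::new();
    key.hash(&mut hasher);
    hasher.finish()
}

fn main() {
    // Same value, same bits, same hash.
    assert_eq!(hash_of(&Key::F64(1.5)), hash_of(&Key::F64(1.5)));
    // The discriminant separates string keys from float keys.
    assert_ne!(hash_of(&Key::Str("1.5".to_string())), hash_of(&Key::F64(1.5)));
    // 0.0 == -0.0, but their bit patterns (and hence hashes) differ.
    assert_ne!(hash_of(&Key::F64(0.0)), hash_of(&Key::F64(-0.0)));
}
```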