diff --git a/examples/aggregation.rs b/examples/aggregation.rs index ae11dc5a5..fb0d131c1 100644 --- a/examples/aggregation.rs +++ b/examples/aggregation.rs @@ -110,6 +110,7 @@ fn main() -> tantivy::Result<()> { (9f64..14f64).into(), (14f64..20f64).into(), ], + ..Default::default() }), sub_aggregation: sub_agg_req_1.clone(), }), diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index 898c816ad..97118b7cf 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -20,6 +20,7 @@ //! bucket_agg: BucketAggregationType::Range(RangeAggregation{ //! field: "score".to_string(), //! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()], +//! keyed: None, //! }), //! sub_aggregation: Default::default(), //! }), @@ -100,6 +101,12 @@ pub(crate) struct BucketAggregationInternal { } impl BucketAggregationInternal { + pub(crate) fn as_range(&self) -> Option<&RangeAggregation> { + match &self.bucket_agg { + BucketAggregationType::Range(range) => Some(range), + _ => None, + } + } pub(crate) fn as_histogram(&self) -> Option<&HistogramAggregation> { match &self.bucket_agg { BucketAggregationType::Histogram(histogram) => Some(histogram), @@ -264,6 +271,7 @@ mod tests { (7f64..20f64).into(), (20f64..f64::MAX).into(), ], + ..Default::default() }), sub_aggregation: Default::default(), }), @@ -312,6 +320,7 @@ mod tests { (7f64..20f64).into(), (20f64..f64::MAX).into(), ], + ..Default::default() }), sub_aggregation: Default::default(), }), @@ -337,6 +346,7 @@ mod tests { (7f64..20f64).into(), (20f64..f64::MAX).into(), ], + ..Default::default() }), sub_aggregation: agg_req2, }), diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 491faf213..0c84f99e3 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -77,8 +77,7 @@ impl BucketAggregationWithAccessor { let mut inverted_index = None; let (accessor, field_type) = match &bucket { BucketAggregationType::Range(RangeAggregation { - field: field_name, - ranges: _, + field: field_name, .. }) => get_ff_reader_and_validate(reader, field_name, Cardinality::SingleValue)?, BucketAggregationType::Histogram(HistogramAggregation { field: field_name, .. diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index fc614990b..f004b7a51 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -104,7 +104,7 @@ pub enum BucketResult { /// sub_aggregations. Range { /// The range buckets sorted by range. - buckets: Vec, + buckets: BucketEntries, }, /// This is the histogram entry for a bucket, which contains a key, count, and optionally /// sub_aggregations. @@ -114,7 +114,7 @@ pub enum BucketResult { /// If there are holes depends on the request, if min_doc_count is 0, then there are no /// holes between the first and last bucket. /// See [HistogramAggregation](super::bucket::HistogramAggregation) - buckets: Vec, + buckets: BucketEntries, }, /// This is the term result Terms { @@ -137,6 +137,17 @@ impl BucketResult { } } +/// This is the wrapper of buckets entries, which can be vector or hashmap +/// depending on if it's keyed or not. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum BucketEntries { + /// Vector format bucket entries + Vec(Vec), + /// HashMap format bucket entries + HashMap(HashMap), +} + /// This is the default entry for a bucket, which contains a key, count, and optionally /// sub_aggregations. /// diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 70acf0f11..c26417513 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -117,6 +117,9 @@ pub struct HistogramAggregation { /// Cannot be set in conjunction with min_doc_count > 0, since the empty buckets from extended /// bounds would not be returned. pub extended_bounds: Option, + /// Whether to return the buckets as a hash map + #[serde(skip_serializing_if = "Option::is_none")] + pub keyed: Option, } impl HistogramAggregation { diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 7faa500e7..0aa48bc8f 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -35,8 +35,6 @@ use crate::{DocId, TantivyError}; /// # Limitations/Compatibility /// Overlapping ranges are not yet supported. /// -/// The keyed parameter (elasticsearch) is not yet supported. -/// /// # Request JSON Format /// ```json /// { @@ -51,13 +49,16 @@ use crate::{DocId, TantivyError}; /// } /// } /// ``` -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] pub struct RangeAggregation { /// The field to aggregate on. pub field: String, /// Note that this aggregation includes the from value and excludes the to value for each /// range. Extra buckets will be created until the first to, and last from, if necessary. pub ranges: Vec, + /// Whether to return the buckets as a hash map + #[serde(skip_serializing_if = "Option::is_none")] + pub keyed: Option, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -406,6 +407,7 @@ mod tests { let req = RangeAggregation { field: "dummy".to_string(), ranges, + ..Default::default() }; SegmentRangeCollector::from_req_and_validate( @@ -427,6 +429,7 @@ mod tests { bucket_agg: BucketAggregationType::Range(RangeAggregation { field: "fraction_f64".to_string(), ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()], + ..Default::default() }), sub_aggregation: Default::default(), }), diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index b6b38bfde..222225a88 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -21,7 +21,7 @@ use super::bucket::{ use super::metric::{IntermediateAverage, IntermediateStats}; use super::segment_agg_result::SegmentMetricResultCollector; use super::{Key, SerializedKey, VecWithNames}; -use crate::aggregation::agg_result::{AggregationResults, BucketEntry}; +use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry}; use crate::aggregation::bucket::TermsAggregationInternal; /// Contains the intermediate aggregation result, which is optimized to be merged with other @@ -281,6 +281,21 @@ impl IntermediateBucketResult { .unwrap_or(f64::MIN) .total_cmp(&right.from.unwrap_or(f64::MIN)) }); + + let is_keyed = req + .as_range() + .expect("unexpected aggregation, expected range aggregation") + .keyed + .is_some(); + let buckets = if is_keyed { + let mut bucket_map = HashMap::new(); + for bucket in buckets { + bucket_map.insert(bucket.key.to_string(), bucket); + } + BucketEntries::HashMap(bucket_map) + } else { + BucketEntries::Vec(buckets) + }; Ok(BucketResult::Range { buckets }) } IntermediateBucketResult::Histogram { buckets } => { @@ -291,6 +306,15 @@ impl IntermediateBucketResult { &req.sub_aggregation, )?; + let buckets = if req.as_histogram().unwrap().keyed.is_some() { + let mut bucket_map = HashMap::new(); + for bucket in buckets { + bucket_map.insert(bucket.key.to_string(), bucket); + } + BucketEntries::HashMap(bucket_map) + } else { + BucketEntries::Vec(buckets) + }; Ok(BucketResult::Histogram { buckets }) } IntermediateBucketResult::Terms(terms) => terms.into_final_result( diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index 2f704b17d..edfec4788 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -285,6 +285,7 @@ mod tests { (7f64..19f64).into(), (19f64..20f64).into(), ], + ..Default::default() }), sub_aggregation: iter::once(( "stats".to_string(), diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 69ae782db..658c9c41e 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -132,6 +132,7 @@ //! bucket_agg: BucketAggregationType::Range(RangeAggregation{ //! field: "score".to_string(), //! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()], +//! keyed: None, //! }), //! sub_aggregation: sub_agg_req_1.clone(), //! }), @@ -765,6 +766,7 @@ mod tests { bucket_agg: BucketAggregationType::Range(RangeAggregation { field: "score".to_string(), ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()], + ..Default::default() }), sub_aggregation: Default::default(), }), @@ -775,6 +777,7 @@ mod tests { bucket_agg: BucketAggregationType::Range(RangeAggregation { field: "score_f64".to_string(), ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()], + ..Default::default() }), sub_aggregation: Default::default(), }), @@ -785,6 +788,7 @@ mod tests { bucket_agg: BucketAggregationType::Range(RangeAggregation { field: "score_i64".to_string(), ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()], + ..Default::default() }), sub_aggregation: Default::default(), }), @@ -941,6 +945,7 @@ mod tests { (7f64..19f64).into(), (19f64..20f64).into(), ], + ..Default::default() }), sub_aggregation: sub_agg_req.clone(), }), @@ -955,6 +960,7 @@ mod tests { (7f64..19f64).into(), (19f64..20f64).into(), ], + ..Default::default() }), sub_aggregation: sub_agg_req.clone(), }), @@ -969,6 +975,7 @@ mod tests { (7f64..19f64).into(), (19f64..20f64).into(), ], + ..Default::default() }), sub_aggregation: sub_agg_req, }), @@ -1416,6 +1423,7 @@ mod tests { (40000f64..50000f64).into(), (50000f64..60000f64).into(), ], + ..Default::default() }), sub_aggregation: Default::default(), }), @@ -1575,6 +1583,7 @@ mod tests { (7000f64..20000f64).into(), (20000f64..60000f64).into(), ], + ..Default::default() }), sub_aggregation: sub_agg_req_1.clone(), }),