diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index df325e8fa..d83673c5e 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -87,6 +87,10 @@ pub enum BucketResult { /// sub_aggregations. Histogram { /// The buckets. + /// + /// If there are holes depends on the request, if min_doc_count is 0, then there are no + /// holes between the first and last bucket. + /// See [HistogramAggregation](super::agg_request::HistogramAggregation) buckets: Vec, }, } diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 1ff49b49a..9d8bb51cd 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -1,3 +1,5 @@ +use std::fmt::Display; + use serde::{Deserialize, Serialize}; use crate::aggregation::agg_req_with_accessor::{ @@ -27,7 +29,7 @@ use crate::{DocId, TantivyError}; /// Setting min_doc_count to != 0 will filter empty buckets. /// /// The value range of the buckets can bet extended via -/// [extended_bounds](HistogramAggregation::extended_bounds) or set to a predefined range via +/// [extended_bounds](HistogramAggregation::extended_bounds) or limit the range via /// [hard_bounds](HistogramAggregation::hard_bounds). /// /// # Result @@ -68,8 +70,12 @@ pub struct HistogramAggregation { pub offset: Option, /// The minimum number of documents in a bucket to be returned. Defaults to 0. pub min_doc_count: Option, - /// Sets a hard limit for the data range. + /// Limit the data range. + /// /// This can be used to filter values if they are not in the data range. + /// + /// hard_bounds only limits the buckets, to force a range set both extended_bounds and + /// hard_bounds to the same range. pub hard_bounds: Option, /// Can be set to extend your bounds. The range of the buckets is by default defined by the /// data range of the values of the documents. As the name suggests, this can only be used to @@ -95,6 +101,17 @@ impl HistogramAggregation { )); } + if let (Some(hard_bounds), Some(extended_bounds)) = (self.hard_bounds, self.extended_bounds) + { + if extended_bounds.min < hard_bounds.min || extended_bounds.max > hard_bounds.max { + return Err(TantivyError::InvalidArgument(format!( + "extended_bounds have to be inside hard_bounds, extended_bounds: {}, \ + hard_bounds {}", + extended_bounds, hard_bounds + ))); + } + } + Ok(()) } @@ -113,6 +130,12 @@ pub struct HistogramBounds { pub max: f64, } +impl Display for HistogramBounds { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("[{},{}]", self.min, self.max)) + } +} + impl HistogramBounds { fn contains(&self, val: f64) -> bool { val >= self.min && val <= self.max @@ -359,8 +382,8 @@ fn get_req_min_max(req: &HistogramAggregation, mut min: f64, mut max: f64) -> (f max = max.max(extended_bounds.max); } if let Some(hard_bounds) = &req.hard_bounds { - min = hard_bounds.min; - max = hard_bounds.max; + min = min.max(hard_bounds.min); + max = max.min(hard_bounds.max); } (min, max) @@ -461,7 +484,7 @@ fn generate_buckets_test() { let buckets = generate_buckets(&histogram_req, 0.0, 10.0); assert_eq!(buckets, vec![2.0, 4.0]); - // With hard_bounds extending + // With hard_bounds, extending has no effect let histogram_req = HistogramAggregation { field: "dummy".to_string(), interval: 2.0, @@ -473,7 +496,7 @@ fn generate_buckets_test() { }; let buckets = generate_buckets(&histogram_req, 2.5, 5.5); - assert_eq!(buckets, vec![0.0, 2.0, 4.0, 6.0, 8.0, 10.0]); + assert_eq!(buckets, vec![2.0, 4.0]); // Blubber let histogram_req = HistogramAggregation { @@ -489,6 +512,7 @@ fn generate_buckets_test() { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::Value; use super::*; @@ -508,7 +532,7 @@ mod tests { let reader = index.reader()?; let searcher = reader.searcher(); - let agg_res = searcher.search(&AllQuery, &collector).unwrap(); + let agg_res = searcher.search(&AllQuery, &collector)?; let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?; Ok(res) @@ -761,12 +785,80 @@ mod tests { let res = exec_request(agg_req, &index)?; + assert_eq!(res["histogram"]["buckets"][0]["key"], 10.0); + assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 3); + assert_eq!(res["histogram"]["buckets"][1]["key"], 11.0); + assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 0); + assert_eq!(res["histogram"]["buckets"][2]["key"], 12.0); + assert_eq!(res["histogram"]["buckets"][2]["doc_count"], 2); + + assert_eq!(res["histogram"]["buckets"][3], Value::Null); + + // hard_bounds and extended_bounds will act like a force bounds + // + let agg_req: Aggregations = vec![( + "histogram".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { + field: "score_f64".to_string(), + interval: 1.0, + hard_bounds: Some(HistogramBounds { + min: 2.0, + max: 12.0, + }), + extended_bounds: Some(HistogramBounds { + min: 2.0, + max: 12.0, + }), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let res = exec_request(agg_req, &index)?; + assert_eq!(res["histogram"]["buckets"][0]["key"], 2.0); assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 0); + assert_eq!(res["histogram"]["buckets"][1]["key"], 3.0); + assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 0); assert_eq!(res["histogram"]["buckets"][10]["key"], 12.0); assert_eq!(res["histogram"]["buckets"][10]["doc_count"], 2); + assert_eq!(res["histogram"]["buckets"][11], Value::Null); + // Invalid request + let agg_req: Aggregations = vec![( + "histogram".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { + field: "score_f64".to_string(), + interval: 1.0, + hard_bounds: Some(HistogramBounds { + min: 2.0, + max: 12.0, + }), + extended_bounds: Some(HistogramBounds { + min: 1.0, + max: 12.0, + }), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let res = exec_request(agg_req, &index).unwrap_err(); + assert_eq!( + res.to_string(), + "An invalid argument was passed: 'extended_bounds have to be inside hard_bounds, \ + extended_bounds: [1,12], hard_bounds [2,12]'" + ); + Ok(()) } diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 4b109bddc..36c74c38e 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -122,7 +122,7 @@ impl AggregationSegmentCollector { let result = SegmentAggregationResultsCollector::from_req_and_validate(&aggs_with_accessor)?; Ok(AggregationSegmentCollector { - aggs: aggs_with_accessor.into(), + aggs: aggs_with_accessor, result, }) }