diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs
index fe936004d..e32a8eb4e 100644
--- a/src/aggregation/bucket/range.rs
+++ b/src/aggregation/bucket/range.rs
@@ -18,7 +18,7 @@ use crate::{DocId, TantivyError};
 
 /// Provide user-defined buckets to aggregate on.
 /// Two special buckets will automatically be created to cover the whole range of values.
-/// The provided buckets have to be continous.
+/// The provided buckets have to be continuous.
 /// During the aggregation, the values extracted from the fast_field `field` will be checked
 /// against each bucket range. Note that this aggregation includes the from value and excludes the
 /// to value for each range.
@@ -64,6 +64,9 @@ pub struct RangeAggregation {
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
 /// The range for one range bucket.
 pub struct RangeAggregationRange {
+    /// Custom key for the range bucket
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub key: Option<String>,
     /// The from range value, which is inclusive in the range.
     /// None equals to an open ended interval.
     #[serde(skip_serializing_if = "Option::is_none", default)]
@@ -86,7 +89,26 @@ impl From<Range<f64>> for RangeAggregationRange {
         } else {
             Some(range.end)
         };
-        RangeAggregationRange { from, to }
+        RangeAggregationRange {
+            key: None,
+            from,
+            to,
+        }
+    }
+}
+
+#[derive(Clone, Debug, PartialEq)]
+/// Internally used u64 range for one range bucket.
+pub(crate) struct InternalRangeAggregationRange {
+    /// Custom key for the range bucket
+    key: Option<String>,
+    /// u64 range value
+    range: Range<u64>,
+}
+
+impl From<Range<u64>> for InternalRangeAggregationRange {
+    fn from(range: Range<u64>) -> Self {
+        InternalRangeAggregationRange { key: None, range }
     }
 }
 
@@ -185,15 +207,20 @@ impl SegmentRangeCollector {
         let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)?
             .iter()
             .map(|range| {
-                let to = if range.end == u64::MAX {
+                let key = range
+                    .key
+                    .clone()
+                    .map(|key| Key::Str(key))
+                    .unwrap_or(range_to_key(&range.range, &field_type));
+                let to = if range.range.end == u64::MAX {
                     None
                 } else {
-                    Some(f64_from_fastfield_u64(range.end, &field_type))
+                    Some(f64_from_fastfield_u64(range.range.end, &field_type))
                 };
-                let from = if range.start == u64::MIN {
+                let from = if range.range.start == u64::MIN {
                     None
                 } else {
-                    Some(f64_from_fastfield_u64(range.start, &field_type))
+                    Some(f64_from_fastfield_u64(range.range.start, &field_type))
                 };
                 let sub_aggregation = if sub_aggregation.is_empty() {
                     None
@@ -203,11 +230,11 @@ impl SegmentRangeCollector {
                     )?)
                 };
                 Ok(SegmentRangeAndBucketEntry {
-                    range: range.clone(),
+                    range: range.range.clone(),
                     bucket: SegmentRangeBucketEntry {
-                        key: range_to_key(range, &field_type),
                         doc_count: 0,
                         sub_aggregation,
+                        key,
                         from,
                         to,
                     },
@@ -306,7 +333,10 @@ impl SegmentRangeCollector {
 /// fast field.
 /// The alternative would be that every value read would be converted to the f64 range, but that is
 /// more computational expensive when many documents are hit.
-fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Result<Range<u64>> {
+fn to_u64_range(
+    range: &RangeAggregationRange,
+    field_type: &Type,
+) -> crate::Result<InternalRangeAggregationRange> {
     let start = if let Some(from) = range.from {
         f64_to_fastfield_u64(from, field_type)
             .ok_or_else(|| TantivyError::InvalidArgument("invalid field type".to_string()))?
@@ -321,7 +351,10 @@ fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Resu
         u64::MAX
     };
 
-    Ok(start..end)
+    Ok(InternalRangeAggregationRange {
+        key: range.key.clone(),
+        range: start..end,
+    })
 }
 
 /// Extends the provided buckets to contain the whole value range, by inserting buckets at the
@@ -329,31 +362,32 @@ fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Resu
 fn extend_validate_ranges(
     buckets: &[RangeAggregationRange],
     field_type: &Type,
-) -> crate::Result<Vec<Range<u64>>> {
+) -> crate::Result<Vec<InternalRangeAggregationRange>> {
     let mut converted_buckets = buckets
         .iter()
         .map(|range| to_u64_range(range, field_type))
         .collect::<crate::Result<Vec<_>>>()?;
-    converted_buckets.sort_by_key(|bucket| bucket.start);
-    if converted_buckets[0].start != u64::MIN {
-        converted_buckets.insert(0, u64::MIN..converted_buckets[0].start);
+    converted_buckets.sort_by_key(|bucket| bucket.range.start);
+    if converted_buckets[0].range.start != u64::MIN {
+        converted_buckets.insert(0, (u64::MIN..converted_buckets[0].range.start).into());
     }
 
-    if converted_buckets[converted_buckets.len() - 1].end != u64::MAX {
-        converted_buckets.push(converted_buckets[converted_buckets.len() - 1].end..u64::MAX);
+    if converted_buckets[converted_buckets.len() - 1].range.end != u64::MAX {
+        converted_buckets
+            .push((converted_buckets[converted_buckets.len() - 1].range.end..u64::MAX).into());
     }
 
     // fill up holes in the ranges
-    let find_hole = |converted_buckets: &[Range<u64>]| {
+    let find_hole = |converted_buckets: &[InternalRangeAggregationRange]| {
         for (pos, ranges) in converted_buckets.windows(2).enumerate() {
-            if ranges[0].end > ranges[1].start {
+            if ranges[0].range.end > ranges[1].range.start {
                 return Err(TantivyError::InvalidArgument(format!(
                     "Overlapping ranges not supported range {:?}, range+1 {:?}",
                     ranges[0], ranges[1]
                 )));
             }
 
-            if ranges[0].end != ranges[1].start {
+            if ranges[0].range.end != ranges[1].range.start {
                 return Ok(Some(pos));
             }
         }
@@ -361,8 +395,9 @@ fn extend_validate_ranges(
     };
 
     while let Some(hole_pos) = find_hole(&converted_buckets)? {
-        let new_range = converted_buckets[hole_pos].end..converted_buckets[hole_pos + 1].start;
-        converted_buckets.insert(hole_pos + 1, new_range);
+        let new_range =
+            converted_buckets[hole_pos].range.end..converted_buckets[hole_pos + 1].range.start;
+        converted_buckets.insert(hole_pos + 1, new_range.into());
     }
 
     Ok(converted_buckets)
@@ -370,7 +405,7 @@ pub(crate) fn range_to_string(range: &Range<u64>, field_type: &Type) -> String {
 
 pub(crate) fn range_to_string(range: &Range<u64>, field_type: &Type) -> String {
     // is_start is there for malformed requests, e.g. ig the user passes the range u64::MIN..0.0,
-    // it should be rendererd as "*-0" and not "*-*"
+    // it should be rendered as "*-0" and not "*-*"
     let to_str = |val: u64, is_start: bool| {
         if (is_start && val == u64::MIN) || (!is_start && val == u64::MAX) {
             "*".to_string()
@@ -389,16 +424,12 @@ pub(crate) fn range_to_key(range: &Range<u64>, field_type: &Type) -> Key {
 
 #[cfg(test)]
 mod tests {
-    use serde_json::Value;
-
     use super::*;
     use crate::aggregation::agg_req::{
         Aggregation, Aggregations, BucketAggregation, BucketAggregationType,
     };
-    use crate::aggregation::tests::get_test_index_with_num_docs;
-    use crate::aggregation::AggregationCollector;
+    use crate::aggregation::tests::{exec_request_with_query, get_test_index_with_num_docs};
     use crate::fastfield::FastValue;
-    use crate::query::AllQuery;
 
     pub fn get_collector_from_ranges(
         ranges: Vec<RangeAggregationRange>,
@@ -437,13 +468,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req, None);
-
-        let reader = index.reader()?;
-        let searcher = reader.searcher();
-        let agg_res = searcher.search(&AllQuery, &collector).unwrap();
-
-        let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
+        let res = exec_request_with_query(agg_req, &index, None)?;
 
         assert_eq!(res["range"]["buckets"][0]["key"], "*-0");
         assert_eq!(res["range"]["buckets"][0]["doc_count"], 0);
@@ -475,13 +500,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req, None);
-
-        let reader = index.reader()?;
-        let searcher = reader.searcher();
-        let agg_res = searcher.search(&AllQuery, &collector).unwrap();
-
-        let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
+        let res = exec_request_with_query(agg_req, &index, None)?;
 
         assert_eq!(
             res,
@@ -500,6 +519,94 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn range_custom_key_test() -> crate::Result<()> {
+        let index = get_test_index_with_num_docs(false, 100)?;
+
+        let agg_req: Aggregations = vec![(
+            "range".to_string(),
+            Aggregation::Bucket(BucketAggregation {
+                bucket_agg: BucketAggregationType::Range(RangeAggregation {
+                    field: "fraction_f64".to_string(),
+                    ranges: vec![
+                        RangeAggregationRange {
+                            key: Some("custom-key-0-to-0.1".to_string()),
+                            from: Some(0f64),
+                            to: Some(0.1f64),
+                        },
+                        RangeAggregationRange {
+                            key: None,
+                            from: Some(0.1f64),
+                            to: Some(0.2f64),
+                        },
+                    ],
+                    keyed: false,
+                }),
+                sub_aggregation: Default::default(),
+            }),
+        )]
+        .into_iter()
+        .collect();
+
+        let res = exec_request_with_query(agg_req, &index, None)?;
+
+        assert_eq!(
+            res,
+            json!({
+                "range": {
+                    "buckets": [
+                        {"key": "*-0", "doc_count": 0, "to": 0.0},
+                        {"key": "custom-key-0-to-0.1", "doc_count": 10, "from": 0.0, "to": 0.1},
+                        {"key": "0.1-0.2", "doc_count": 10, "from": 0.1, "to": 0.2},
+                        {"key": "0.2-*", "doc_count": 80, "from": 0.2}
+                    ]
+                }
+            })
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn range_custom_key_keyed_buckets_test() -> crate::Result<()> {
+        let index = get_test_index_with_num_docs(false, 100)?;
+
+        let agg_req: Aggregations = vec![(
+            "range".to_string(),
+            Aggregation::Bucket(BucketAggregation {
+                bucket_agg: BucketAggregationType::Range(RangeAggregation {
+                    field: "fraction_f64".to_string(),
+                    ranges: vec![RangeAggregationRange {
+                        key: Some("custom-key-0-to-0.1".to_string()),
+                        from: Some(0f64),
+                        to: Some(0.1f64),
+                    }],
+                    keyed: true,
+                }),
+                sub_aggregation: Default::default(),
+            }),
+        )]
+        .into_iter()
+        .collect();
+
+        let res = exec_request_with_query(agg_req, &index, None)?;
+
+        assert_eq!(
+            res,
+            json!({
+                "range": {
+                    "buckets": {
+                        "*-0": { "key": "*-0", "doc_count": 0, "to": 0.0},
+                        "custom-key-0-to-0.1": {"key": "custom-key-0-to-0.1", "doc_count": 10, "from": 0.0, "to": 0.1},
+                        "0.1-*": {"key": "0.1-*", "doc_count": 90, "from": 0.1},
+                    }
+                }
+            })
+        );
+
+        Ok(())
+    }
+
     #[test]
     fn bucket_test_extend_range_hole() {
         let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
@@ -578,6 +685,7 @@ mod tests {
 
         let ranges = vec![
             RangeAggregationRange {
+                key: None,
                 to: Some(10.0),
                 from: None,
             },
@@ -587,11 +695,13 @@ mod tests {
 
         let ranges = vec![
             RangeAggregationRange {
+                key: None,
                 to: Some(10.0),
                 from: None,
            },
             (10.0..100.0).into(),
             RangeAggregationRange {
+                key: None,
                 to: None,
                 from: Some(100.0),
             },
diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs
index 9c708b891..1d6e35383 100644
--- a/src/aggregation/mod.rs
+++ b/src/aggregation/mod.rs
@@ -377,7 +377,7 @@ mod tests {
             searcher.search(&AllQuery, &collector)?
         };
 
-        // Test serialization/deserialization rountrip
+        // Test serialization/deserialization roundtrip
         let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
         Ok(res)
     }