From 195309a5579df8151fa6c276450145235db4995d Mon Sep 17 00:00:00 2001
From: k-yomo
Date: Wed, 27 Jul 2022 09:13:50 +0900
Subject: [PATCH 1/3] Add support for custom key param for range aggregation

---
 src/aggregation/bucket/range.rs | 188 ++++++++++++++++++++++++++++----
 1 file changed, 167 insertions(+), 21 deletions(-)

diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs
index fe936004d..bc2a04793 100644
--- a/src/aggregation/bucket/range.rs
+++ b/src/aggregation/bucket/range.rs
@@ -64,6 +64,9 @@ pub struct RangeAggregation {
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
 /// The range for one range bucket.
 pub struct RangeAggregationRange {
+    /// Custom key for the range bucket
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub key: Option<String>,
     /// The from range value, which is inclusive in the range.
     /// None equals to an open ended interval.
     #[serde(skip_serializing_if = "Option::is_none", default)]
@@ -86,10 +89,23 @@ impl From<Range<f64>> for RangeAggregationRange {
         } else {
             Some(range.end)
         };
-        RangeAggregationRange { from, to }
+        RangeAggregationRange {
+            key: None,
+            from,
+            to,
+        }
     }
 }
 
+#[derive(Clone, Debug, PartialEq)]
+/// Internally used u64 range for one range bucket.
+pub(crate) struct InternalRangeAggregationRange {
+    /// Custom key for the range bucket
+    key: Option<String>,
+    /// u64 range value
+    range: Range<u64>,
+}
+
 #[derive(Clone, Debug, PartialEq)]
 pub(crate) struct SegmentRangeAndBucketEntry {
     range: Range<u64>,
@@ -185,15 +201,20 @@ impl SegmentRangeCollector {
         let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)?
             .iter()
             .map(|range| {
-                let to = if range.end == u64::MAX {
+                let key = range
+                    .key
+                    .clone()
+                    .map(|key| Key::Str(key))
+                    .unwrap_or(range_to_key(&range.range, &field_type));
+                let to = if range.range.end == u64::MAX {
                     None
                 } else {
-                    Some(f64_from_fastfield_u64(range.end, &field_type))
+                    Some(f64_from_fastfield_u64(range.range.end, &field_type))
                 };
-                let from = if range.start == u64::MIN {
+                let from = if range.range.start == u64::MIN {
                     None
                 } else {
-                    Some(f64_from_fastfield_u64(range.start, &field_type))
+                    Some(f64_from_fastfield_u64(range.range.start, &field_type))
                 };
                 let sub_aggregation = if sub_aggregation.is_empty() {
                     None
                 } else {
                     Some(SegmentAggregationResultsCollector::from_req_and_validate(
                     )?)
                 };
                 Ok(SegmentRangeAndBucketEntry {
-                    range: range.clone(),
+                    range: range.range.clone(),
                     bucket: SegmentRangeBucketEntry {
-                        key: range_to_key(range, &field_type),
                         doc_count: 0,
                         sub_aggregation,
+                        key,
                         from,
                         to,
                     },
@@ -306,7 +327,10 @@ impl SegmentRangeCollector {
 /// fast field.
 /// The alternative would be that every value read would be converted to the f64 range, but that is
 /// more computationally expensive when many documents are hit.
-fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Result<Range<u64>> {
+fn to_u64_range(
+    range: &RangeAggregationRange,
+    field_type: &Type,
+) -> crate::Result<InternalRangeAggregationRange> {
     let start = if let Some(from) = range.from {
         f64_to_fastfield_u64(from, field_type)
             .ok_or_else(|| TantivyError::InvalidArgument("invalid field type".to_string()))?
@@ -321,7 +345,10 @@ fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Resu
         u64::MAX
     };
 
-    Ok(start..end)
+    Ok(InternalRangeAggregationRange {
+        key: range.key.clone(),
+        range: start..end,
+    })
 }
 
 /// Extends the provided buckets to contain the whole value range, by inserting buckets at the
@@ -329,31 +356,40 @@ fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Resu
 fn extend_validate_ranges(
     buckets: &[RangeAggregationRange],
     field_type: &Type,
-) -> crate::Result<Vec<Range<u64>>> {
+) -> crate::Result<Vec<InternalRangeAggregationRange>> {
     let mut converted_buckets = buckets
         .iter()
         .map(|range| to_u64_range(range, field_type))
         .collect::<crate::Result<Vec<_>>>()?;
 
-    converted_buckets.sort_by_key(|bucket| bucket.start);
-    if converted_buckets[0].start != u64::MIN {
-        converted_buckets.insert(0, u64::MIN..converted_buckets[0].start);
+    converted_buckets.sort_by_key(|bucket| bucket.range.start);
+    if converted_buckets[0].range.start != u64::MIN {
+        converted_buckets.insert(
+            0,
+            InternalRangeAggregationRange {
+                key: None,
+                range: u64::MIN..converted_buckets[0].range.start,
+            },
+        );
     }
 
-    if converted_buckets[converted_buckets.len() - 1].end != u64::MAX {
-        converted_buckets.push(converted_buckets[converted_buckets.len() - 1].end..u64::MAX);
+    if converted_buckets[converted_buckets.len() - 1].range.end != u64::MAX {
+        converted_buckets.push(InternalRangeAggregationRange {
+            key: None,
+            range: converted_buckets[converted_buckets.len() - 1].range.end..u64::MAX,
+        });
     }
 
     // fill up holes in the ranges
-    let find_hole = |converted_buckets: &[Range<u64>]| {
+    let find_hole = |converted_buckets: &[InternalRangeAggregationRange]| {
         for (pos, ranges) in converted_buckets.windows(2).enumerate() {
-            if ranges[0].end > ranges[1].start {
+            if ranges[0].range.end > ranges[1].range.start {
                 return Err(TantivyError::InvalidArgument(format!(
                     "Overlapping ranges not supported range {:?}, range+1 {:?}",
                     ranges[0], ranges[1]
                 )));
             }
-            if ranges[0].end != ranges[1].start {
+            if ranges[0].range.end != ranges[1].range.start {
                 return Ok(Some(pos));
             }
         }
@@ -361,8 +397,15 @@ fn extend_validate_ranges(
     };
 
     while let Some(hole_pos) = find_hole(&converted_buckets)? {
-        let new_range = converted_buckets[hole_pos].end..converted_buckets[hole_pos + 1].start;
-        converted_buckets.insert(hole_pos + 1, new_range);
+        let new_range =
+            converted_buckets[hole_pos].range.end..converted_buckets[hole_pos + 1].range.start;
+        converted_buckets.insert(
+            hole_pos + 1,
+            InternalRangeAggregationRange {
+                key: None,
+                range: new_range,
+            },
+        );
     }
 
     Ok(converted_buckets)
@@ -370,7 +413,7 @@ fn extend_validate_ranges(
 
 pub(crate) fn range_to_string(range: &Range<u64>, field_type: &Type) -> String {
     // is_start is there for malformed requests, e.g. if the user passes the range u64::MIN..0.0,
-    // it should be rendererd as "*-0" and not "*-*"
+    // it should be rendered as "*-0" and not "*-*"
     let to_str = |val: u64, is_start: bool| {
         if (is_start && val == u64::MIN) || (!is_start && val == u64::MAX) {
             "*".to_string()
@@ -500,6 +543,106 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn range_custom_key_test() -> crate::Result<()> {
+        let index = get_test_index_with_num_docs(false, 100)?;
+
+        let agg_req: Aggregations = vec![(
+            "range".to_string(),
+            Aggregation::Bucket(BucketAggregation {
+                bucket_agg: BucketAggregationType::Range(RangeAggregation {
+                    field: "fraction_f64".to_string(),
+                    ranges: vec![
+                        RangeAggregationRange {
+                            key: Some("custom-key-0-to-0.1".to_string()),
+                            from: Some(0f64),
+                            to: Some(0.1f64),
+                        },
+                        RangeAggregationRange {
+                            key: None,
+                            from: Some(0.1f64),
+                            to: Some(0.2f64),
+                        },
+                    ],
+                    keyed: false,
+                }),
+                sub_aggregation: Default::default(),
+            }),
+        )]
+        .into_iter()
+        .collect();
+
+        let collector = AggregationCollector::from_aggs(agg_req, None);
+
+        let reader = index.reader()?;
+        let searcher = reader.searcher();
+        let agg_res = searcher.search(&AllQuery, &collector).unwrap();
+
+        let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
+
+        assert_eq!(
+            res,
+            json!({
+                "range": {
+                    "buckets": [
+                        {"key": "*-0", "doc_count": 0, "to": 0.0},
+                        {"key": "custom-key-0-to-0.1", "doc_count": 10, "from": 0.0, "to": 0.1},
+                        {"key": "0.1-0.2", "doc_count": 10, "from": 0.1, "to": 0.2},
+                        {"key": "0.2-*", "doc_count": 80, "from": 0.2}
+                    ]
+                }
+            })
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn range_custom_key_keyed_buckets_test() -> crate::Result<()> {
+        let index = get_test_index_with_num_docs(false, 100)?;
+
+        let agg_req: Aggregations = vec![(
+            "range".to_string(),
+            Aggregation::Bucket(BucketAggregation {
+                bucket_agg: BucketAggregationType::Range(RangeAggregation {
+                    field: "fraction_f64".to_string(),
+                    ranges: vec![RangeAggregationRange {
+                        key: Some("custom-key-0-to-0.1".to_string()),
+                        from: Some(0f64),
+                        to: Some(0.1f64),
+                    }],
+                    keyed: true,
+                }),
+                sub_aggregation: Default::default(),
+            }),
+        )]
+        .into_iter()
+        .collect();
+
+        let collector = AggregationCollector::from_aggs(agg_req, None);
+
+        let reader = index.reader()?;
+        let searcher = reader.searcher();
+        let agg_res = searcher.search(&AllQuery, &collector).unwrap();
+
+        let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
+
+        assert_eq!(
+            res,
+            json!({
+                "range": {
+                    "buckets": {
+                        "*-0": { "key": "*-0", "doc_count": 0, "to": 0.0},
+                        "custom-key-0-to-0.1": {"key": "custom-key-0-to-0.1", "doc_count": 10, "from": 0.0, "to": 0.1},
+                        "0.1-*": {"key": "0.1-*", "doc_count": 90, "from": 0.1},
+                    }
+                }
+            })
+        );
+
+        Ok(())
+    }
+
     #[test]
     fn bucket_test_extend_range_hole() {
         let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
@@ -578,6 +721,7 @@ mod tests {
 
         let ranges = vec![
             RangeAggregationRange {
+                key: None,
                 to: Some(10.0),
                 from: None,
             },

         let ranges = vec![
             RangeAggregationRange {
+                key: None,
                 to: Some(10.0),
                 from: None,
             },
             (10.0..100.0).into(),
             RangeAggregationRange {
+                key: None,
                 to: None,
                 from: Some(100.0),
             },

From 704d0a8d8bc3adb4efb1dbb8c55916b98ccd5b22 Mon Sep 17 00:00:00 2001
From: k-yomo
Date: Thu, 28 Jul 2022 06:31:25 +0900
Subject: [PATCH 2/3] Refactor range aggregation tests

---
 src/aggregation/bucket/range.rs | 38 +++++----------------------------
 src/aggregation/mod.rs          |  2 +-
 2 files changed, 6 insertions(+), 34 deletions(-)

diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs
index bc2a04793..24c5c9715 100644
--- a/src/aggregation/bucket/range.rs
+++ b/src/aggregation/bucket/range.rs
@@ -432,16 +432,12 @@ pub(crate) fn range_to_key(range: &Range<u64>, field_type: &Type) -> Key {
 
 #[cfg(test)]
 mod tests {
-    use serde_json::Value;
-
     use super::*;
     use crate::aggregation::agg_req::{
         Aggregation, Aggregations, BucketAggregation, BucketAggregationType,
     };
-    use crate::aggregation::tests::get_test_index_with_num_docs;
-    use crate::aggregation::AggregationCollector;
+    use crate::aggregation::tests::{exec_request_with_query, get_test_index_with_num_docs};
     use crate::fastfield::FastValue;
-    use crate::query::AllQuery;
 
     pub fn get_collector_from_ranges(
         ranges: Vec<RangeAggregationRange>,
@@ -480,13 +476,7 @@
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req, None);
-
-        let reader = index.reader()?;
-        let searcher = reader.searcher();
-        let agg_res = searcher.search(&AllQuery, &collector).unwrap();
-
-        let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
+        let res = exec_request_with_query(agg_req, &index, None)?;
 
         assert_eq!(res["range"]["buckets"][0]["key"], "*-0");
         assert_eq!(res["range"]["buckets"][0]["doc_count"], 0);
@@ -518,13 +508,7 @@
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req, None);
-
-        let reader = index.reader()?;
-        let searcher = reader.searcher();
-        let agg_res = searcher.search(&AllQuery, &collector).unwrap();
-
-        let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
+        let res = exec_request_with_query(agg_req, &index, None)?;
 
         assert_eq!(
             res,
@@ -572,13 +556,7 @@
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req, None);
-
-        let reader = index.reader()?;
-        let searcher = reader.searcher();
-        let agg_res = searcher.search(&AllQuery, &collector).unwrap();
-
-        let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
+        let res = exec_request_with_query(agg_req, &index, None)?;
 
         assert_eq!(
             res,
@@ -619,13 +597,7 @@
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req, None);
-
-        let reader = index.reader()?;
-        let searcher = reader.searcher();
-        let agg_res = searcher.search(&AllQuery, &collector).unwrap();
-
-        let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
+        let res = exec_request_with_query(agg_req, &index, None)?;
 
         assert_eq!(
             res,
diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs
index 9c708b891..1d6e35383 100644
--- a/src/aggregation/mod.rs
+++ b/src/aggregation/mod.rs
@@ -377,7 +377,7 @@ mod tests {
             searcher.search(&AllQuery, &collector)?
         };
 
-        // Test serialization/deserialization rountrip
+        // Test serialization/deserialization roundtrip
         let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
         Ok(res)
     }

From 099e62615624e5cc0c34302f25670a5e8c05ce0c Mon Sep 17 00:00:00 2001
From: k-yomo
Date: Fri, 29 Jul 2022 05:41:29 +0900
Subject: [PATCH 3/3] Refactor InternalRangeAggregationRange initialization
 with From trait

---
 src/aggregation/bucket/range.rs | 30 +++++++++++-------------------
 1 file changed, 11 insertions(+), 19 deletions(-)

diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs
index 24c5c9715..e32a8eb4e 100644
--- a/src/aggregation/bucket/range.rs
+++ b/src/aggregation/bucket/range.rs
@@ -18,7 +18,7 @@ use crate::{DocId, TantivyError};
 
 /// Provide user-defined buckets to aggregate on.
 /// Two special buckets will automatically be created to cover the whole range of values.
-/// The provided buckets have to be continous.
+/// The provided buckets have to be continuous.
 /// During the aggregation, the values extracted from the fast_field `field` will be checked
 /// against each bucket range. Note that this aggregation includes the from value and excludes the
 /// to value for each range.
@@ -106,6 +112,12 @@ pub(crate) struct InternalRangeAggregationRange {
     range: Range<u64>,
 }
 
+impl From<Range<u64>> for InternalRangeAggregationRange {
+    fn from(range: Range<u64>) -> Self {
+        InternalRangeAggregationRange { key: None, range }
+    }
+}
+
 #[derive(Clone, Debug, PartialEq)]
 pub(crate) struct SegmentRangeAndBucketEntry {
     range: Range<u64>,
@@ -364,20 +370,12 @@ fn extend_validate_ranges(
 
     converted_buckets.sort_by_key(|bucket| bucket.range.start);
     if converted_buckets[0].range.start != u64::MIN {
-        converted_buckets.insert(
-            0,
-            InternalRangeAggregationRange {
-                key: None,
-                range: u64::MIN..converted_buckets[0].range.start,
-            },
-        );
+        converted_buckets.insert(0, (u64::MIN..converted_buckets[0].range.start).into());
     }
 
     if converted_buckets[converted_buckets.len() - 1].range.end != u64::MAX {
-        converted_buckets.push(InternalRangeAggregationRange {
-            key: None,
-            range: converted_buckets[converted_buckets.len() - 1].range.end..u64::MAX,
-        });
+        converted_buckets
+            .push((converted_buckets[converted_buckets.len() - 1].range.end..u64::MAX).into());
     }
 
     // fill up holes in the ranges
@@ -399,13 +397,7 @@ fn extend_validate_ranges(
     while let Some(hole_pos) = find_hole(&converted_buckets)? {
         let new_range =
            converted_buckets[hole_pos].range.end..converted_buckets[hole_pos + 1].range.start;
-        converted_buckets.insert(
-            hole_pos + 1,
-            InternalRangeAggregationRange {
-                key: None,
-                range: new_range,
-            },
-        );
+        converted_buckets.insert(hole_pos + 1, new_range.into());
     }
 
     Ok(converted_buckets)
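
A condensed usage sketch, distilled from the tests added in PATCH 1, showing the feature from the caller's side. The struct fields (`key`, `from`, `to`, `field`, `ranges`, `keyed`) come straight from the diff above; the import path, field name, and helper function are assumptions for illustration only. A bucket with `key: Some(..)` is reported under that name, a bucket with `key: None` falls back to the auto-generated "from-to" label, and the two open-ended buckets covering the rest of the value range are still created automatically.

    use tantivy::aggregation::bucket::{RangeAggregation, RangeAggregationRange};

    fn build_range_agg() -> RangeAggregation {
        RangeAggregation {
            // Fast field to bucket on; "fraction_f64" is the field name used by the tests.
            field: "fraction_f64".to_string(),
            ranges: vec![
                RangeAggregationRange {
                    // Custom bucket key, serialized as "key" via the new serde field;
                    // this bucket shows up as "low" instead of "0-0.1".
                    key: Some("low".to_string()),
                    from: Some(0.0),
                    to: Some(0.1),
                },
                RangeAggregationRange {
                    // No custom key: this bucket keeps the generated "0.1-0.2" label.
                    key: None,
                    from: Some(0.1),
                    to: Some(0.2),
                },
            ],
            // With keyed: true the response maps each custom or generated key to its
            // bucket object instead of returning a bucket array, as exercised by
            // range_custom_key_keyed_buckets_test above.
            keyed: false,
        }
    }

Because `key` is declared with `#[serde(skip_serializing_if = "Option::is_none", default)]`, requests that omit it keep deserializing exactly as before, so existing JSON aggregation requests are unaffected.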