Merge pull request #1426 from k-yomo/support-custom-key-in-range-aggregation

Add support for custom key param for range aggregation
This commit is contained in:
PSeitz
2022-08-03 04:31:02 -07:00
committed by GitHub
2 changed files with 152 additions and 42 deletions

View File

@@ -18,7 +18,7 @@ use crate::{DocId, TantivyError};
/// Provide user-defined buckets to aggregate on.
/// Two special buckets will automatically be created to cover the whole range of values.
/// The provided buckets have to be continous.
/// The provided buckets have to be continuous.
/// During the aggregation, the values extracted from the fast_field `field` will be checked
/// against each bucket range. Note that this aggregation includes the from value and excludes the
/// to value for each range.
@@ -64,6 +64,9 @@ pub struct RangeAggregation {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// The range for one range bucket.
pub struct RangeAggregationRange {
/// Custom key for the range bucket
#[serde(skip_serializing_if = "Option::is_none", default)]
pub key: Option<String>,
/// The from range value, which is inclusive in the range.
/// None equals to an open ended interval.
#[serde(skip_serializing_if = "Option::is_none", default)]
@@ -86,7 +89,26 @@ impl From<Range<f64>> for RangeAggregationRange {
} else {
Some(range.end)
};
RangeAggregationRange { from, to }
RangeAggregationRange {
key: None,
from,
to,
}
}
}
#[derive(Clone, Debug, PartialEq)]
/// Internally used u64 range for one range bucket.
pub(crate) struct InternalRangeAggregationRange {
/// Custom key for the range bucket
key: Option<String>,
/// u64 range value
range: Range<u64>,
}
impl From<Range<u64>> for InternalRangeAggregationRange {
fn from(range: Range<u64>) -> Self {
InternalRangeAggregationRange { key: None, range }
}
}
@@ -185,15 +207,20 @@ impl SegmentRangeCollector {
let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)?
.iter()
.map(|range| {
let to = if range.end == u64::MAX {
let key = range
.key
.clone()
.map(|key| Key::Str(key))
.unwrap_or(range_to_key(&range.range, &field_type));
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.end, &field_type))
Some(f64_from_fastfield_u64(range.range.end, &field_type))
};
let from = if range.start == u64::MIN {
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.start, &field_type))
Some(f64_from_fastfield_u64(range.range.start, &field_type))
};
let sub_aggregation = if sub_aggregation.is_empty() {
None
@@ -203,11 +230,11 @@ impl SegmentRangeCollector {
)?)
};
Ok(SegmentRangeAndBucketEntry {
range: range.clone(),
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
key: range_to_key(range, &field_type),
doc_count: 0,
sub_aggregation,
key,
from,
to,
},
@@ -306,7 +333,10 @@ impl SegmentRangeCollector {
/// fast field.
/// The alternative would be that every value read would be converted to the f64 range, but that is
/// more computational expensive when many documents are hit.
fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Result<Range<u64>> {
fn to_u64_range(
range: &RangeAggregationRange,
field_type: &Type,
) -> crate::Result<InternalRangeAggregationRange> {
let start = if let Some(from) = range.from {
f64_to_fastfield_u64(from, field_type)
.ok_or_else(|| TantivyError::InvalidArgument("invalid field type".to_string()))?
@@ -321,7 +351,10 @@ fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Resu
u64::MAX
};
Ok(start..end)
Ok(InternalRangeAggregationRange {
key: range.key.clone(),
range: start..end,
})
}
/// Extends the provided buckets to contain the whole value range, by inserting buckets at the
@@ -329,31 +362,32 @@ fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Resu
fn extend_validate_ranges(
buckets: &[RangeAggregationRange],
field_type: &Type,
) -> crate::Result<Vec<Range<u64>>> {
) -> crate::Result<Vec<InternalRangeAggregationRange>> {
let mut converted_buckets = buckets
.iter()
.map(|range| to_u64_range(range, field_type))
.collect::<crate::Result<Vec<_>>>()?;
converted_buckets.sort_by_key(|bucket| bucket.start);
if converted_buckets[0].start != u64::MIN {
converted_buckets.insert(0, u64::MIN..converted_buckets[0].start);
converted_buckets.sort_by_key(|bucket| bucket.range.start);
if converted_buckets[0].range.start != u64::MIN {
converted_buckets.insert(0, (u64::MIN..converted_buckets[0].range.start).into());
}
if converted_buckets[converted_buckets.len() - 1].end != u64::MAX {
converted_buckets.push(converted_buckets[converted_buckets.len() - 1].end..u64::MAX);
if converted_buckets[converted_buckets.len() - 1].range.end != u64::MAX {
converted_buckets
.push((converted_buckets[converted_buckets.len() - 1].range.end..u64::MAX).into());
}
// fill up holes in the ranges
let find_hole = |converted_buckets: &[Range<u64>]| {
let find_hole = |converted_buckets: &[InternalRangeAggregationRange]| {
for (pos, ranges) in converted_buckets.windows(2).enumerate() {
if ranges[0].end > ranges[1].start {
if ranges[0].range.end > ranges[1].range.start {
return Err(TantivyError::InvalidArgument(format!(
"Overlapping ranges not supported range {:?}, range+1 {:?}",
ranges[0], ranges[1]
)));
}
if ranges[0].end != ranges[1].start {
if ranges[0].range.end != ranges[1].range.start {
return Ok(Some(pos));
}
}
@@ -361,8 +395,9 @@ fn extend_validate_ranges(
};
while let Some(hole_pos) = find_hole(&converted_buckets)? {
let new_range = converted_buckets[hole_pos].end..converted_buckets[hole_pos + 1].start;
converted_buckets.insert(hole_pos + 1, new_range);
let new_range =
converted_buckets[hole_pos].range.end..converted_buckets[hole_pos + 1].range.start;
converted_buckets.insert(hole_pos + 1, new_range.into());
}
Ok(converted_buckets)
@@ -370,7 +405,7 @@ fn extend_validate_ranges(
pub(crate) fn range_to_string(range: &Range<u64>, field_type: &Type) -> String {
// is_start is there for malformed requests, e.g. ig the user passes the range u64::MIN..0.0,
// it should be rendererd as "*-0" and not "*-*"
// it should be rendered as "*-0" and not "*-*"
let to_str = |val: u64, is_start: bool| {
if (is_start && val == u64::MIN) || (!is_start && val == u64::MAX) {
"*".to_string()
@@ -389,16 +424,12 @@ pub(crate) fn range_to_key(range: &Range<u64>, field_type: &Type) -> Key {
#[cfg(test)]
mod tests {
use serde_json::Value;
use super::*;
use crate::aggregation::agg_req::{
Aggregation, Aggregations, BucketAggregation, BucketAggregationType,
};
use crate::aggregation::tests::get_test_index_with_num_docs;
use crate::aggregation::AggregationCollector;
use crate::aggregation::tests::{exec_request_with_query, get_test_index_with_num_docs};
use crate::fastfield::FastValue;
use crate::query::AllQuery;
pub fn get_collector_from_ranges(
ranges: Vec<RangeAggregationRange>,
@@ -437,13 +468,7 @@ mod tests {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req, None);
let reader = index.reader()?;
let searcher = reader.searcher();
let agg_res = searcher.search(&AllQuery, &collector).unwrap();
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
let res = exec_request_with_query(agg_req, &index, None)?;
assert_eq!(res["range"]["buckets"][0]["key"], "*-0");
assert_eq!(res["range"]["buckets"][0]["doc_count"], 0);
@@ -475,13 +500,7 @@ mod tests {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req, None);
let reader = index.reader()?;
let searcher = reader.searcher();
let agg_res = searcher.search(&AllQuery, &collector).unwrap();
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
let res = exec_request_with_query(agg_req, &index, None)?;
assert_eq!(
res,
@@ -500,6 +519,94 @@ mod tests {
Ok(())
}
#[test]
fn range_custom_key_test() -> crate::Result<()> {
let index = get_test_index_with_num_docs(false, 100)?;
let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![
RangeAggregationRange {
key: Some("custom-key-0-to-0.1".to_string()),
from: Some(0f64),
to: Some(0.1f64),
},
RangeAggregationRange {
key: None,
from: Some(0.1f64),
to: Some(0.2f64),
},
],
keyed: false,
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request_with_query(agg_req, &index, None)?;
assert_eq!(
res,
json!({
"range": {
"buckets": [
{"key": "*-0", "doc_count": 0, "to": 0.0},
{"key": "custom-key-0-to-0.1", "doc_count": 10, "from": 0.0, "to": 0.1},
{"key": "0.1-0.2", "doc_count": 10, "from": 0.1, "to": 0.2},
{"key": "0.2-*", "doc_count": 80, "from": 0.2}
]
}
})
);
Ok(())
}
#[test]
fn range_custom_key_keyed_buckets_test() -> crate::Result<()> {
let index = get_test_index_with_num_docs(false, 100)?;
let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![RangeAggregationRange {
key: Some("custom-key-0-to-0.1".to_string()),
from: Some(0f64),
to: Some(0.1f64),
}],
keyed: true,
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request_with_query(agg_req, &index, None)?;
assert_eq!(
res,
json!({
"range": {
"buckets": {
"*-0": { "key": "*-0", "doc_count": 0, "to": 0.0},
"custom-key-0-to-0.1": {"key": "custom-key-0-to-0.1", "doc_count": 10, "from": 0.0, "to": 0.1},
"0.1-*": {"key": "0.1-*", "doc_count": 90, "from": 0.1},
}
}
})
);
Ok(())
}
#[test]
fn bucket_test_extend_range_hole() {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
@@ -578,6 +685,7 @@ mod tests {
let ranges = vec![
RangeAggregationRange {
key: None,
to: Some(10.0),
from: None,
},
@@ -587,11 +695,13 @@ mod tests {
let ranges = vec![
RangeAggregationRange {
key: None,
to: Some(10.0),
from: None,
},
(10.0..100.0).into(),
RangeAggregationRange {
key: None,
to: None,
from: Some(100.0),
},

View File

@@ -377,7 +377,7 @@ mod tests {
searcher.search(&AllQuery, &collector)?
};
// Test serialization/deserialization rountrip
// Test serialization/deserialization roundtrip
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
Ok(res)
}