mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-05-23 19:50:42 +00:00
Merge pull request #1426 from k-yomo/support-custom-key-in-range-aggregation
Add support for custom key param for range aggregation
This commit is contained in:
@@ -18,7 +18,7 @@ use crate::{DocId, TantivyError};
|
||||
|
||||
/// Provide user-defined buckets to aggregate on.
|
||||
/// Two special buckets will automatically be created to cover the whole range of values.
|
||||
/// The provided buckets have to be continous.
|
||||
/// The provided buckets have to be continuous.
|
||||
/// During the aggregation, the values extracted from the fast_field `field` will be checked
|
||||
/// against each bucket range. Note that this aggregation includes the from value and excludes the
|
||||
/// to value for each range.
|
||||
@@ -64,6 +64,9 @@ pub struct RangeAggregation {
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
/// The range for one range bucket.
|
||||
pub struct RangeAggregationRange {
|
||||
/// Custom key for the range bucket
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub key: Option<String>,
|
||||
/// The from range value, which is inclusive in the range.
|
||||
/// None equals to an open ended interval.
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
@@ -86,7 +89,26 @@ impl From<Range<f64>> for RangeAggregationRange {
|
||||
} else {
|
||||
Some(range.end)
|
||||
};
|
||||
RangeAggregationRange { from, to }
|
||||
RangeAggregationRange {
|
||||
key: None,
|
||||
from,
|
||||
to,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
/// Internally used u64 range for one range bucket.
|
||||
pub(crate) struct InternalRangeAggregationRange {
|
||||
/// Custom key for the range bucket
|
||||
key: Option<String>,
|
||||
/// u64 range value
|
||||
range: Range<u64>,
|
||||
}
|
||||
|
||||
impl From<Range<u64>> for InternalRangeAggregationRange {
|
||||
fn from(range: Range<u64>) -> Self {
|
||||
InternalRangeAggregationRange { key: None, range }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,15 +207,20 @@ impl SegmentRangeCollector {
|
||||
let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)?
|
||||
.iter()
|
||||
.map(|range| {
|
||||
let to = if range.end == u64::MAX {
|
||||
let key = range
|
||||
.key
|
||||
.clone()
|
||||
.map(|key| Key::Str(key))
|
||||
.unwrap_or(range_to_key(&range.range, &field_type));
|
||||
let to = if range.range.end == u64::MAX {
|
||||
None
|
||||
} else {
|
||||
Some(f64_from_fastfield_u64(range.end, &field_type))
|
||||
Some(f64_from_fastfield_u64(range.range.end, &field_type))
|
||||
};
|
||||
let from = if range.start == u64::MIN {
|
||||
let from = if range.range.start == u64::MIN {
|
||||
None
|
||||
} else {
|
||||
Some(f64_from_fastfield_u64(range.start, &field_type))
|
||||
Some(f64_from_fastfield_u64(range.range.start, &field_type))
|
||||
};
|
||||
let sub_aggregation = if sub_aggregation.is_empty() {
|
||||
None
|
||||
@@ -203,11 +230,11 @@ impl SegmentRangeCollector {
|
||||
)?)
|
||||
};
|
||||
Ok(SegmentRangeAndBucketEntry {
|
||||
range: range.clone(),
|
||||
range: range.range.clone(),
|
||||
bucket: SegmentRangeBucketEntry {
|
||||
key: range_to_key(range, &field_type),
|
||||
doc_count: 0,
|
||||
sub_aggregation,
|
||||
key,
|
||||
from,
|
||||
to,
|
||||
},
|
||||
@@ -306,7 +333,10 @@ impl SegmentRangeCollector {
|
||||
/// fast field.
|
||||
/// The alternative would be that every value read would be converted to the f64 range, but that is
|
||||
/// more computational expensive when many documents are hit.
|
||||
fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Result<Range<u64>> {
|
||||
fn to_u64_range(
|
||||
range: &RangeAggregationRange,
|
||||
field_type: &Type,
|
||||
) -> crate::Result<InternalRangeAggregationRange> {
|
||||
let start = if let Some(from) = range.from {
|
||||
f64_to_fastfield_u64(from, field_type)
|
||||
.ok_or_else(|| TantivyError::InvalidArgument("invalid field type".to_string()))?
|
||||
@@ -321,7 +351,10 @@ fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Resu
|
||||
u64::MAX
|
||||
};
|
||||
|
||||
Ok(start..end)
|
||||
Ok(InternalRangeAggregationRange {
|
||||
key: range.key.clone(),
|
||||
range: start..end,
|
||||
})
|
||||
}
|
||||
|
||||
/// Extends the provided buckets to contain the whole value range, by inserting buckets at the
|
||||
@@ -329,31 +362,32 @@ fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Resu
|
||||
fn extend_validate_ranges(
|
||||
buckets: &[RangeAggregationRange],
|
||||
field_type: &Type,
|
||||
) -> crate::Result<Vec<Range<u64>>> {
|
||||
) -> crate::Result<Vec<InternalRangeAggregationRange>> {
|
||||
let mut converted_buckets = buckets
|
||||
.iter()
|
||||
.map(|range| to_u64_range(range, field_type))
|
||||
.collect::<crate::Result<Vec<_>>>()?;
|
||||
|
||||
converted_buckets.sort_by_key(|bucket| bucket.start);
|
||||
if converted_buckets[0].start != u64::MIN {
|
||||
converted_buckets.insert(0, u64::MIN..converted_buckets[0].start);
|
||||
converted_buckets.sort_by_key(|bucket| bucket.range.start);
|
||||
if converted_buckets[0].range.start != u64::MIN {
|
||||
converted_buckets.insert(0, (u64::MIN..converted_buckets[0].range.start).into());
|
||||
}
|
||||
|
||||
if converted_buckets[converted_buckets.len() - 1].end != u64::MAX {
|
||||
converted_buckets.push(converted_buckets[converted_buckets.len() - 1].end..u64::MAX);
|
||||
if converted_buckets[converted_buckets.len() - 1].range.end != u64::MAX {
|
||||
converted_buckets
|
||||
.push((converted_buckets[converted_buckets.len() - 1].range.end..u64::MAX).into());
|
||||
}
|
||||
|
||||
// fill up holes in the ranges
|
||||
let find_hole = |converted_buckets: &[Range<u64>]| {
|
||||
let find_hole = |converted_buckets: &[InternalRangeAggregationRange]| {
|
||||
for (pos, ranges) in converted_buckets.windows(2).enumerate() {
|
||||
if ranges[0].end > ranges[1].start {
|
||||
if ranges[0].range.end > ranges[1].range.start {
|
||||
return Err(TantivyError::InvalidArgument(format!(
|
||||
"Overlapping ranges not supported range {:?}, range+1 {:?}",
|
||||
ranges[0], ranges[1]
|
||||
)));
|
||||
}
|
||||
if ranges[0].end != ranges[1].start {
|
||||
if ranges[0].range.end != ranges[1].range.start {
|
||||
return Ok(Some(pos));
|
||||
}
|
||||
}
|
||||
@@ -361,8 +395,9 @@ fn extend_validate_ranges(
|
||||
};
|
||||
|
||||
while let Some(hole_pos) = find_hole(&converted_buckets)? {
|
||||
let new_range = converted_buckets[hole_pos].end..converted_buckets[hole_pos + 1].start;
|
||||
converted_buckets.insert(hole_pos + 1, new_range);
|
||||
let new_range =
|
||||
converted_buckets[hole_pos].range.end..converted_buckets[hole_pos + 1].range.start;
|
||||
converted_buckets.insert(hole_pos + 1, new_range.into());
|
||||
}
|
||||
|
||||
Ok(converted_buckets)
|
||||
@@ -370,7 +405,7 @@ fn extend_validate_ranges(
|
||||
|
||||
pub(crate) fn range_to_string(range: &Range<u64>, field_type: &Type) -> String {
|
||||
// is_start is there for malformed requests, e.g. ig the user passes the range u64::MIN..0.0,
|
||||
// it should be rendererd as "*-0" and not "*-*"
|
||||
// it should be rendered as "*-0" and not "*-*"
|
||||
let to_str = |val: u64, is_start: bool| {
|
||||
if (is_start && val == u64::MIN) || (!is_start && val == u64::MAX) {
|
||||
"*".to_string()
|
||||
@@ -389,16 +424,12 @@ pub(crate) fn range_to_key(range: &Range<u64>, field_type: &Type) -> Key {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
use super::*;
|
||||
use crate::aggregation::agg_req::{
|
||||
Aggregation, Aggregations, BucketAggregation, BucketAggregationType,
|
||||
};
|
||||
use crate::aggregation::tests::get_test_index_with_num_docs;
|
||||
use crate::aggregation::AggregationCollector;
|
||||
use crate::aggregation::tests::{exec_request_with_query, get_test_index_with_num_docs};
|
||||
use crate::fastfield::FastValue;
|
||||
use crate::query::AllQuery;
|
||||
|
||||
pub fn get_collector_from_ranges(
|
||||
ranges: Vec<RangeAggregationRange>,
|
||||
@@ -437,13 +468,7 @@ mod tests {
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let collector = AggregationCollector::from_aggs(agg_req, None);
|
||||
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
let agg_res = searcher.search(&AllQuery, &collector).unwrap();
|
||||
|
||||
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
|
||||
let res = exec_request_with_query(agg_req, &index, None)?;
|
||||
|
||||
assert_eq!(res["range"]["buckets"][0]["key"], "*-0");
|
||||
assert_eq!(res["range"]["buckets"][0]["doc_count"], 0);
|
||||
@@ -475,13 +500,7 @@ mod tests {
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let collector = AggregationCollector::from_aggs(agg_req, None);
|
||||
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
let agg_res = searcher.search(&AllQuery, &collector).unwrap();
|
||||
|
||||
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
|
||||
let res = exec_request_with_query(agg_req, &index, None)?;
|
||||
|
||||
assert_eq!(
|
||||
res,
|
||||
@@ -500,6 +519,94 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn range_custom_key_test() -> crate::Result<()> {
|
||||
let index = get_test_index_with_num_docs(false, 100)?;
|
||||
|
||||
let agg_req: Aggregations = vec![(
|
||||
"range".to_string(),
|
||||
Aggregation::Bucket(BucketAggregation {
|
||||
bucket_agg: BucketAggregationType::Range(RangeAggregation {
|
||||
field: "fraction_f64".to_string(),
|
||||
ranges: vec![
|
||||
RangeAggregationRange {
|
||||
key: Some("custom-key-0-to-0.1".to_string()),
|
||||
from: Some(0f64),
|
||||
to: Some(0.1f64),
|
||||
},
|
||||
RangeAggregationRange {
|
||||
key: None,
|
||||
from: Some(0.1f64),
|
||||
to: Some(0.2f64),
|
||||
},
|
||||
],
|
||||
keyed: false,
|
||||
}),
|
||||
sub_aggregation: Default::default(),
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let res = exec_request_with_query(agg_req, &index, None)?;
|
||||
|
||||
assert_eq!(
|
||||
res,
|
||||
json!({
|
||||
"range": {
|
||||
"buckets": [
|
||||
{"key": "*-0", "doc_count": 0, "to": 0.0},
|
||||
{"key": "custom-key-0-to-0.1", "doc_count": 10, "from": 0.0, "to": 0.1},
|
||||
{"key": "0.1-0.2", "doc_count": 10, "from": 0.1, "to": 0.2},
|
||||
{"key": "0.2-*", "doc_count": 80, "from": 0.2}
|
||||
]
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn range_custom_key_keyed_buckets_test() -> crate::Result<()> {
|
||||
let index = get_test_index_with_num_docs(false, 100)?;
|
||||
|
||||
let agg_req: Aggregations = vec![(
|
||||
"range".to_string(),
|
||||
Aggregation::Bucket(BucketAggregation {
|
||||
bucket_agg: BucketAggregationType::Range(RangeAggregation {
|
||||
field: "fraction_f64".to_string(),
|
||||
ranges: vec![RangeAggregationRange {
|
||||
key: Some("custom-key-0-to-0.1".to_string()),
|
||||
from: Some(0f64),
|
||||
to: Some(0.1f64),
|
||||
}],
|
||||
keyed: true,
|
||||
}),
|
||||
sub_aggregation: Default::default(),
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let res = exec_request_with_query(agg_req, &index, None)?;
|
||||
|
||||
assert_eq!(
|
||||
res,
|
||||
json!({
|
||||
"range": {
|
||||
"buckets": {
|
||||
"*-0": { "key": "*-0", "doc_count": 0, "to": 0.0},
|
||||
"custom-key-0-to-0.1": {"key": "custom-key-0-to-0.1", "doc_count": 10, "from": 0.0, "to": 0.1},
|
||||
"0.1-*": {"key": "0.1-*", "doc_count": 90, "from": 0.1},
|
||||
}
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bucket_test_extend_range_hole() {
|
||||
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
|
||||
@@ -578,6 +685,7 @@ mod tests {
|
||||
|
||||
let ranges = vec![
|
||||
RangeAggregationRange {
|
||||
key: None,
|
||||
to: Some(10.0),
|
||||
from: None,
|
||||
},
|
||||
@@ -587,11 +695,13 @@ mod tests {
|
||||
|
||||
let ranges = vec![
|
||||
RangeAggregationRange {
|
||||
key: None,
|
||||
to: Some(10.0),
|
||||
from: None,
|
||||
},
|
||||
(10.0..100.0).into(),
|
||||
RangeAggregationRange {
|
||||
key: None,
|
||||
to: None,
|
||||
from: Some(100.0),
|
||||
},
|
||||
|
||||
@@ -377,7 +377,7 @@ mod tests {
|
||||
searcher.search(&AllQuery, &collector)?
|
||||
};
|
||||
|
||||
// Test serialization/deserialization rountrip
|
||||
// Test serialization/deserialization roundtrip
|
||||
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user