Merge pull request #1424 from k-yomo/support-keyed-parameter-in-aggregation

Add support for keyed parameter in range and histgram aggregations
This commit is contained in:
PSeitz
2022-07-27 06:22:29 -07:00
committed by GitHub
9 changed files with 159 additions and 12 deletions

View File

@@ -110,6 +110,7 @@ fn main() -> tantivy::Result<()> {
(9f64..14f64).into(),
(14f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req_1.clone(),
}),

View File

@@ -20,6 +20,7 @@
//! bucket_agg: BucketAggregationType::Range(RangeAggregation{
//! field: "score".to_string(),
//! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
//! keyed: false,
//! }),
//! sub_aggregation: Default::default(),
//! }),
@@ -100,6 +101,12 @@ pub(crate) struct BucketAggregationInternal {
}
impl BucketAggregationInternal {
pub(crate) fn as_range(&self) -> Option<&RangeAggregation> {
match &self.bucket_agg {
BucketAggregationType::Range(range) => Some(range),
_ => None,
}
}
pub(crate) fn as_histogram(&self) -> Option<&HistogramAggregation> {
match &self.bucket_agg {
BucketAggregationType::Histogram(histogram) => Some(histogram),
@@ -264,6 +271,7 @@ mod tests {
(7f64..20f64).into(),
(20f64..f64::MAX).into(),
],
keyed: true,
}),
sub_aggregation: Default::default(),
}),
@@ -290,7 +298,8 @@ mod tests {
{
"from": 20.0
}
]
],
"keyed": true
}
}
}"#;
@@ -312,6 +321,7 @@ mod tests {
(7f64..20f64).into(),
(20f64..f64::MAX).into(),
],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
@@ -337,6 +347,7 @@ mod tests {
(7f64..20f64).into(),
(20f64..f64::MAX).into(),
],
..Default::default()
}),
sub_aggregation: agg_req2,
}),

View File

@@ -77,8 +77,7 @@ impl BucketAggregationWithAccessor {
let mut inverted_index = None;
let (accessor, field_type) = match &bucket {
BucketAggregationType::Range(RangeAggregation {
field: field_name,
ranges: _,
field: field_name, ..
}) => get_ff_reader_and_validate(reader, field_name, Cardinality::SingleValue)?,
BucketAggregationType::Histogram(HistogramAggregation {
field: field_name, ..

View File

@@ -6,6 +6,7 @@
use std::collections::HashMap;
use fnv::FnvHashMap;
use serde::{Deserialize, Serialize};
use super::agg_req::BucketAggregationInternal;
@@ -104,7 +105,7 @@ pub enum BucketResult {
/// sub_aggregations.
Range {
/// The range buckets sorted by range.
buckets: Vec<RangeBucketEntry>,
buckets: BucketEntries<RangeBucketEntry>,
},
/// This is the histogram entry for a bucket, which contains a key, count, and optionally
/// sub_aggregations.
@@ -114,7 +115,7 @@ pub enum BucketResult {
/// If there are holes depends on the request, if min_doc_count is 0, then there are no
/// holes between the first and last bucket.
/// See [HistogramAggregation](super::bucket::HistogramAggregation)
buckets: Vec<BucketEntry>,
buckets: BucketEntries<BucketEntry>,
},
/// This is the term result
Terms {
@@ -137,6 +138,17 @@ impl BucketResult {
}
}
/// This is the wrapper of buckets entries, which can be vector or hashmap
/// depending on if it's keyed or not.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum BucketEntries<T> {
/// Vector format bucket entries
Vec(Vec<T>),
/// HashMap format bucket entries
HashMap(FnvHashMap<String, T>),
}
/// This is the default entry for a bucket, which contains a key, count, and optionally
/// sub_aggregations.
///

View File

@@ -48,8 +48,6 @@ use crate::{DocId, TantivyError};
///
/// # Limitations/Compatibility
///
/// The keyed parameter (elasticsearch) is not yet supported.
///
/// # JSON Format
/// ```json
/// {
@@ -117,6 +115,9 @@ pub struct HistogramAggregation {
/// Cannot be set in conjunction with min_doc_count > 0, since the empty buckets from extended
/// bounds would not be returned.
pub extended_bounds: Option<HistogramBounds>,
/// Whether to return the buckets as a hash map
#[serde(default)]
pub keyed: bool,
}
impl HistogramAggregation {
@@ -1395,4 +1396,46 @@ mod tests {
Ok(())
}
#[test]
fn histogram_keyed_buckets_test() -> crate::Result<()> {
let index = get_test_index_with_num_docs(false, 100)?;
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 50.0,
keyed: true,
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(
res,
json!({
"histogram": {
"buckets": {
"0": {
"key": 0.0,
"doc_count": 50
},
"50": {
"key": 50.0,
"doc_count": 50
}
}
}
})
);
Ok(())
}
}

View File

@@ -35,8 +35,6 @@ use crate::{DocId, TantivyError};
/// # Limitations/Compatibility
/// Overlapping ranges are not yet supported.
///
/// The keyed parameter (elasticsearch) is not yet supported.
///
/// # Request JSON Format
/// ```json
/// {
@@ -51,13 +49,16 @@ use crate::{DocId, TantivyError};
/// }
/// }
/// ```
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct RangeAggregation {
/// The field to aggregate on.
pub field: String,
/// Note that this aggregation includes the from value and excludes the to value for each
/// range. Extra buckets will be created until the first to, and last from, if necessary.
pub ranges: Vec<RangeAggregationRange>,
/// Whether to return the buckets as a hash map
#[serde(default)]
pub keyed: bool,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@@ -406,6 +407,7 @@ mod tests {
let req = RangeAggregation {
field: "dummy".to_string(),
ranges,
..Default::default()
};
SegmentRangeCollector::from_req_and_validate(
@@ -427,6 +429,7 @@ mod tests {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
@@ -454,6 +457,49 @@ mod tests {
Ok(())
}
#[test]
fn range_keyed_buckets_test() -> crate::Result<()> {
let index = get_test_index_with_num_docs(false, 100)?;
let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
keyed: true,
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req, None);
let reader = index.reader()?;
let searcher = reader.searcher();
let agg_res = searcher.search(&AllQuery, &collector).unwrap();
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
assert_eq!(
res,
json!({
"range": {
"buckets": {
"*-0": { "key": "*-0", "doc_count": 0, "to": 0.0},
"0-0.1": {"key": "0-0.1", "doc_count": 10, "from": 0.0, "to": 0.1},
"0.1-0.2": {"key": "0.1-0.2", "doc_count": 10, "from": 0.1, "to": 0.2},
"0.2-*": {"key": "0.2-*", "doc_count": 80, "from": 0.2},
}
}
})
);
Ok(())
}
#[test]
fn bucket_test_extend_range_hole() {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];

View File

@@ -21,7 +21,7 @@ use super::bucket::{
use super::metric::{IntermediateAverage, IntermediateStats};
use super::segment_agg_result::SegmentMetricResultCollector;
use super::{Key, SerializedKey, VecWithNames};
use crate::aggregation::agg_result::{AggregationResults, BucketEntry};
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
use crate::aggregation::bucket::TermsAggregationInternal;
/// Contains the intermediate aggregation result, which is optimized to be merged with other
@@ -281,6 +281,21 @@ impl IntermediateBucketResult {
.unwrap_or(f64::MIN)
.total_cmp(&right.from.unwrap_or(f64::MIN))
});
let is_keyed = req
.as_range()
.expect("unexpected aggregation, expected range aggregation")
.keyed;
let buckets = if is_keyed {
let mut bucket_map =
FnvHashMap::with_capacity_and_hasher(buckets.len(), Default::default());
for bucket in buckets {
bucket_map.insert(bucket.key.to_string(), bucket);
}
BucketEntries::HashMap(bucket_map)
} else {
BucketEntries::Vec(buckets)
};
Ok(BucketResult::Range { buckets })
}
IntermediateBucketResult::Histogram { buckets } => {
@@ -291,6 +306,16 @@ impl IntermediateBucketResult {
&req.sub_aggregation,
)?;
let buckets = if req.as_histogram().unwrap().keyed {
let mut bucket_map =
FnvHashMap::with_capacity_and_hasher(buckets.len(), Default::default());
for bucket in buckets {
bucket_map.insert(bucket.key.to_string(), bucket);
}
BucketEntries::HashMap(bucket_map)
} else {
BucketEntries::Vec(buckets)
};
Ok(BucketResult::Histogram { buckets })
}
IntermediateBucketResult::Terms(terms) => terms.into_final_result(

View File

@@ -285,6 +285,7 @@ mod tests {
(7f64..19f64).into(),
(19f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: iter::once((
"stats".to_string(),

View File

@@ -132,6 +132,7 @@
//! bucket_agg: BucketAggregationType::Range(RangeAggregation{
//! field: "score".to_string(),
//! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
//! keyed: false,
//! }),
//! sub_aggregation: sub_agg_req_1.clone(),
//! }),
@@ -517,7 +518,7 @@ mod tests {
"histogram": {
"field": "score",
"interval": 70.0,
"offset": 3.0,
"offset": 3.0
},
"aggs": {
"bucketsL2": {
@@ -765,6 +766,7 @@ mod tests {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score".to_string(),
ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
@@ -775,6 +777,7 @@ mod tests {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
@@ -785,6 +788,7 @@ mod tests {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_i64".to_string(),
ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
@@ -941,6 +945,7 @@ mod tests {
(7f64..19f64).into(),
(19f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req.clone(),
}),
@@ -955,6 +960,7 @@ mod tests {
(7f64..19f64).into(),
(19f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req.clone(),
}),
@@ -969,6 +975,7 @@ mod tests {
(7f64..19f64).into(),
(19f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req,
}),
@@ -1416,6 +1423,7 @@ mod tests {
(40000f64..50000f64).into(),
(50000f64..60000f64).into(),
],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
@@ -1575,6 +1583,7 @@ mod tests {
(7000f64..20000f64).into(),
(20000f64..60000f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req_1.clone(),
}),