support order property on term aggregations

order can be by doc_count, key, or a metric sub_aggregation
Pascal Seitz
2022-04-19 15:14:31 +08:00
parent c7c3eab256
commit 1be6c6111c
10 changed files with 908 additions and 110 deletions
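For orientation, here is a hedged sketch of what such an order can look like in an Elasticsearch-compatible JSON request, parsed through serde into the aggregation request types (the names "genres", "genre", "avg_price", and "price" are made up, and the exact request shape is an assumption based on the crate's ES-compatible serde derives):

use tantivy::aggregation::agg_req::Aggregations;

// A terms aggregation ordered descending by the value of the "avg_price"
// metric sub-aggregation; "_count" and "_key" are valid order targets too.
let agg_req: Aggregations = serde_json::from_str(
    r#"{
        "genres": {
            "terms": { "field": "genre", "order": { "avg_price": "desc" } },
            "aggs": { "avg_price": { "avg": { "field": "price" } } }
        }
    }"#,
)
.unwrap();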

View File

@@ -7,7 +7,6 @@
use std::cmp::Ordering;
use std::collections::HashMap;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use super::agg_req::{Aggregations, AggregationsInternal, BucketAggregationInternal};
@@ -18,19 +17,36 @@ use super::intermediate_agg_result::{
};
use super::metric::{SingleMetricResult, Stats};
use super::Key;
use crate::TantivyError;
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
/// The final aggregation result.
pub struct AggregationResults(pub HashMap<String, AggregationResult>);
impl AggregationResults {
pub(crate) fn get_value_from_aggregation(
&self,
name: &str,
agg_property: &str,
) -> crate::Result<Option<f64>> {
if let Some(agg) = self.0.get(name) {
agg.get_value_from_aggregation(name, agg_property)
} else {
// Should we return an error here? A missing aggregation could be intentional
// to save memory.
// Validation can be done during request parsing.
Ok(None)
}
}
/// Convert an intermediate result and its aggregation request to the final result.
pub fn from_intermediate_and_req(
results: IntermediateAggregationResults,
agg: Aggregations,
) -> Self {
) -> crate::Result<Self> {
AggregationResults::from_intermediate_and_req_internal(results, &(agg.into()))
}
/// Convert an intermediate result and its aggregation request to the final result.
///
/// Internal function; CollectorAggregations is used instead of Aggregations, which is optimized
@@ -38,35 +54,40 @@ impl AggregationResults {
pub(crate) fn from_intermediate_and_req_internal(
results: IntermediateAggregationResults,
req: &AggregationsInternal,
) -> Self {
let mut result = HashMap::default();
) -> crate::Result<Self> {
// Important assumption:
// When the tree contains buckets/metrics, we expect it to have all buckets/metrics
// from the request.
if let Some(buckets) = results.buckets {
result.extend(buckets.into_iter().zip(req.buckets.values()).map(
|((key, bucket), req)| {
(
let mut result: HashMap<_, _> = if let Some(buckets) = results.buckets {
buckets
.into_iter()
.zip(req.buckets.values())
.map(|((key, bucket), req)| {
Ok((
key,
AggregationResult::BucketResult(BucketResult::from_intermediate_and_req(
bucket, req,
)),
)
},
));
)?),
))
})
.collect::<crate::Result<HashMap<_, _>>>()?
} else {
result.extend(req.buckets.iter().map(|(key, req)| {
let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg);
(
key.to_string(),
AggregationResult::BucketResult(BucketResult::from_intermediate_and_req(
empty_bucket,
req,
)),
)
}));
}
req.buckets
.iter()
.map(|(key, req)| {
let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg);
Ok((
key.to_string(),
AggregationResult::BucketResult(BucketResult::from_intermediate_and_req(
empty_bucket,
req,
)?),
))
})
.collect::<crate::Result<HashMap<_, _>>>()?
};
if let Some(metrics) = results.metrics {
result.extend(
@@ -83,7 +104,7 @@ impl AggregationResults {
)
}));
}
Self(result)
Ok(Self(result))
}
}
@@ -97,6 +118,23 @@ pub enum AggregationResult {
MetricResult(MetricResult),
}
impl AggregationResult {
pub(crate) fn get_value_from_aggregation(
&self,
_name: &str,
agg_property: &str,
) -> crate::Result<Option<f64>> {
match self {
AggregationResult::BucketResult(_bucket) => Err(TantivyError::InvalidArgument(
"bucket aggregation not supported to retrieve value, only metrics aggregations \
are supported."
.to_string(),
)),
AggregationResult::MetricResult(metric) => metric.get_value(agg_property),
}
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
/// The result of a metric aggregation.
@@ -107,6 +145,14 @@ pub enum MetricResult {
Stats(Stats),
}
impl MetricResult {
fn get_value(&self, agg_property: &str) -> crate::Result<Option<f64>> {
match self {
MetricResult::Average(avg) => Ok(avg.value),
MetricResult::Stats(stats) => stats.get_value(agg_property),
}
}
}
impl From<IntermediateMetricResult> for MetricResult {
fn from(metric: IntermediateMetricResult) -> Self {
match metric {
@@ -158,32 +204,34 @@ impl BucketResult {
fn from_intermediate_and_req(
bucket_result: IntermediateBucketResult,
req: &BucketAggregationInternal,
) -> Self {
) -> crate::Result<Self> {
match bucket_result {
IntermediateBucketResult::Range(range_map) => {
let mut buckets: Vec<RangeBucketEntry> = range_map
IntermediateBucketResult::Range(range_res) => {
let mut buckets: Vec<RangeBucketEntry> = range_res
.buckets
.into_iter()
.map(|(_, bucket)| {
RangeBucketEntry::from_intermediate_and_req(bucket, &req.sub_aggregation)
})
.collect_vec();
.collect::<crate::Result<Vec<_>>>()?;
buckets.sort_by(|a, b| {
// TODO: use total_cmp once it is available in stable Rust
a.from
.unwrap_or(f64::MIN)
.partial_cmp(&b.from.unwrap_or(f64::MIN))
.unwrap_or(Ordering::Equal)
});
BucketResult::Range { buckets }
Ok(BucketResult::Range { buckets })
}
IntermediateBucketResult::Histogram { buckets } => {
let buckets = intermediate_buckets_to_final_buckets(
buckets,
req.as_histogram().expect("unexpected aggregation"),
&req.sub_aggregation,
);
)?;
BucketResult::Histogram { buckets }
Ok(BucketResult::Histogram { buckets })
}
IntermediateBucketResult::Terms(terms) => terms.into_final_result(
req.as_term().expect("unexpected aggregation"),
@@ -226,7 +274,7 @@ pub struct BucketEntry {
/// Number of documents in the bucket.
pub doc_count: u64,
#[serde(flatten)]
/// sub-aggregations in this bucket.
/// Sub-aggregations in this bucket.
pub sub_aggregation: AggregationResults,
}
@@ -234,15 +282,20 @@ impl BucketEntry {
pub(crate) fn from_intermediate_and_req(
entry: IntermediateHistogramBucketEntry,
req: &AggregationsInternal,
) -> Self {
BucketEntry {
) -> crate::Result<Self> {
Ok(BucketEntry {
key: Key::F64(entry.key),
doc_count: entry.doc_count,
sub_aggregation: AggregationResults::from_intermediate_and_req_internal(
entry.sub_aggregation,
req,
),
}
)?,
})
}
}
impl GetDocCount for &BucketEntry {
fn doc_count(&self) -> u64 {
self.doc_count
}
}
impl GetDocCount for BucketEntry {
@@ -302,16 +355,16 @@ impl RangeBucketEntry {
fn from_intermediate_and_req(
entry: IntermediateRangeBucketEntry,
req: &AggregationsInternal,
) -> Self {
RangeBucketEntry {
) -> crate::Result<Self> {
Ok(RangeBucketEntry {
key: entry.key,
doc_count: entry.doc_count,
sub_aggregation: AggregationResults::from_intermediate_and_req_internal(
entry.sub_aggregation,
req,
),
)?,
to: entry.to,
from: entry.from,
}
})
}
}

View File

@@ -425,7 +425,7 @@ fn intermediate_buckets_to_final_buckets_fill_gaps(
buckets: Vec<IntermediateHistogramBucketEntry>,
histogram_req: &HistogramAggregation,
sub_aggregation: &AggregationsInternal,
) -> Vec<BucketEntry> {
) -> crate::Result<Vec<BucketEntry>> {
// Generate the full list of buckets without gaps.
//
// The bounds are the min max from the current buckets, optionally extended by
@@ -468,7 +468,7 @@ fn intermediate_buckets_to_final_buckets_fill_gaps(
.map(|intermediate_bucket| {
BucketEntry::from_intermediate_and_req(intermediate_bucket, sub_aggregation)
})
.collect_vec()
.collect::<crate::Result<Vec<_>>>()
}
// Convert to BucketEntry
@@ -476,7 +476,7 @@ pub(crate) fn intermediate_buckets_to_final_buckets(
buckets: Vec<IntermediateHistogramBucketEntry>,
histogram_req: &HistogramAggregation,
sub_aggregation: &AggregationsInternal,
) -> Vec<BucketEntry> {
) -> crate::Result<Vec<BucketEntry>> {
if histogram_req.min_doc_count() == 0 {
// With min_doc_count != 0, we may need to add buckets, so that there are no
// gaps, since intermediate result does not contain empty buckets (filtered to
@@ -488,7 +488,7 @@ pub(crate) fn intermediate_buckets_to_final_buckets(
.into_iter()
.filter(|bucket| bucket.doc_count >= histogram_req.min_doc_count())
.map(|bucket| BucketEntry::from_intermediate_and_req(bucket, sub_aggregation))
.collect_vec()
.collect::<crate::Result<Vec<_>>>()
}
}

View File

@@ -11,8 +11,124 @@ mod histogram;
mod range;
mod term_agg;
use std::collections::HashMap;
pub(crate) use histogram::SegmentHistogramCollector;
pub use histogram::*;
pub(crate) use range::SegmentRangeCollector;
pub use range::*;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
pub use term_agg::*;
/// Order for buckets in a bucket aggregation.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub enum Order {
/// Ascending order
#[serde(rename = "asc")]
Asc,
/// Descending order
#[serde(rename = "desc")]
Desc,
}
impl Default for Order {
fn default() -> Self {
Order::Desc
}
}
#[derive(Clone, Debug, PartialEq)]
/// The property by which the buckets are ordered.
pub enum OrderTarget {
/// The key of the bucket
Key,
/// The doc count of the bucket
Count,
/// Order by the value of the sub-aggregation metric identified by the given `String`.
///
/// Currently only single-value metrics are supported.
SubAggregation(String),
}
impl Default for OrderTarget {
fn default() -> Self {
OrderTarget::Count
}
}
impl From<&str> for OrderTarget {
fn from(val: &str) -> Self {
match val {
"_key" => OrderTarget::Key,
"_count" => OrderTarget::Count,
_ => OrderTarget::SubAggregation(val.to_string()),
}
}
}
impl ToString for OrderTarget {
fn to_string(&self) -> String {
match self {
OrderTarget::Key => "_key".to_string(),
OrderTarget::Count => "_count".to_string(),
OrderTarget::SubAggregation(agg) => agg.to_string(),
}
}
}
/// Set the order. The target is either "_count", "_key", or the name of
/// a metric sub_aggregation.
///
/// De/Serializes to elasticsearch compatible JSON.
///
/// Examples in JSON format:
/// { "_count": "asc" }
/// { "_key": "asc" }
/// { "average_price": "asc" }
#[derive(Clone, Default, Debug, PartialEq)]
pub struct CustomOrder {
/// The target property to sort by
pub target: OrderTarget,
/// The order: asc or desc
pub order: Order,
}
impl Serialize for CustomOrder {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer {
let map: HashMap<String, Order> =
std::iter::once((self.target.to_string(), self.order)).collect();
map.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for CustomOrder {
fn deserialize<D>(deserializer: D) -> Result<CustomOrder, D::Error>
where D: Deserializer<'de> {
HashMap::<String, Order>::deserialize(deserializer).and_then(|map| {
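// Note: only the first entry of the deserialized map is used; any further
// entries in the order object are silently ignored.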
if let Some((key, value)) = map.into_iter().next() {
Ok(CustomOrder {
target: key.as_str().into(),
order: value,
})
} else {
Err(de::Error::custom(
"unexpected empty map in order".to_string(),
))
}
})
}
}
#[test]
fn custom_order_serde_test() {
let order = CustomOrder {
target: OrderTarget::Key,
order: Order::Desc,
};
let order_str = serde_json::to_string(&order).unwrap();
assert_eq!(order_str, "{\"_key\":\"desc\"}");
let order_deser = serde_json::from_str(&order_str).unwrap();
assert_eq!(order, order_deser);
}
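As a complementary sketch (not part of this commit's tests), a metric order target deserializes through the `From<&str>` fallthrough above; the metric name "avg_price" is illustrative:

#[test]
fn custom_order_sub_agg_target_deser_sketch() {
    let order: CustomOrder = serde_json::from_str(r#"{"avg_price":"asc"}"#).unwrap();
    assert_eq!(
        order.target,
        OrderTarget::SubAggregation("avg_price".to_string())
    );
    assert_eq!(order.order, Order::Asc);
}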

View File

@@ -7,7 +7,7 @@ use crate::aggregation::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateBucketResult, IntermediateRangeBucketEntry,
IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector;
use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key};
@@ -166,7 +166,9 @@ impl SegmentRangeCollector {
})
.collect::<crate::Result<_>>()?;
Ok(IntermediateBucketResult::Range(buckets))
Ok(IntermediateBucketResult::Range(
IntermediateRangeBucketResult { buckets },
))
}
pub(crate) fn from_req_and_validate(

View File

@@ -1,8 +1,10 @@
use std::fmt::Debug;
use fnv::FnvHashMap;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use super::{CustomOrder, Order, OrderTarget};
use crate::aggregation::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor,
};
@@ -13,7 +15,7 @@ use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector;
use crate::error::DataCorruption;
use crate::fastfield::MultiValuedFastFieldReader;
use crate::schema::Type;
use crate::DocId;
use crate::{DocId, TantivyError};
/// Creates a bucket for every unique term
///
@@ -62,7 +64,7 @@ use crate::DocId;
/// }
/// }
/// ```
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct TermsAggregation {
/// The field to aggregate on.
pub field: String,
@@ -91,30 +93,22 @@ pub struct TermsAggregation {
/// doc_count returned by each shard. It's the sum of the size of the largest bucket on
/// each segment that didn't fit into `shard_size`.
///
/// Defaults to true.
#[serde(default = "default_show_term_doc_count_error")]
pub show_term_doc_count_error: bool,
/// Defaults to true when ordering by counts desc.
pub show_term_doc_count_error: Option<bool>,
/// Filter out all terms with a doc count lower than `min_doc_count`. Defaults to 1.
///
/// **Expensive**: When set to 0, this will return all terms in the field.
pub min_doc_count: Option<u64>,
}
impl Default for TermsAggregation {
fn default() -> Self {
Self {
field: Default::default(),
size: Default::default(),
shard_size: Default::default(),
show_term_doc_count_error: true,
min_doc_count: Default::default(),
segment_size: Default::default(),
}
}
}
fn default_show_term_doc_count_error() -> bool {
true
}
/// Set the order. The `String` target is either "_count", "_key", or the name of
/// a metric sub_aggregation.
///
/// Examples in JSON format:
/// { "_count": "asc" }
/// { "_key": "asc" }
/// { "average_price": "asc" }
pub order: Option<CustomOrder>,
}
/// Same as TermsAggregation, but with populated defaults.
@@ -143,6 +137,8 @@ pub(crate) struct TermsAggregationInternal {
///
/// *Expensive*: When set to 0, this will return all terms in the field.
pub min_doc_count: u64,
pub order: CustomOrder,
}
impl TermsAggregationInternal {
@@ -151,13 +147,17 @@ impl TermsAggregationInternal {
let mut segment_size = req.segment_size.unwrap_or(size * 10);
let order = req.order.clone().unwrap_or_default();
segment_size = segment_size.max(size);
TermsAggregationInternal {
field: req.field.to_string(),
size,
segment_size,
show_term_doc_count_error: req.show_term_doc_count_error,
show_term_doc_count_error: req
.show_term_doc_count_error
.unwrap_or_else(|| order == CustomOrder::default()),
min_doc_count: req.min_doc_count.unwrap_or(1),
order,
}
}
}
@@ -269,6 +269,11 @@ pub struct SegmentTermCollector {
blueprint: Option<SegmentAggregationResultsCollector>,
}
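// Splits a sub-aggregation order target on the first '.': e.g. "stats_score.avg"
// yields ("stats_score", "avg"), while a plain "avg_score" yields ("avg_score", "").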
pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
let (agg_name, agg_property) = name.split_once('.').unwrap_or((name, ""));
(agg_name, agg_property)
}
impl SegmentTermCollector {
pub(crate) fn from_req_and_validate(
req: &TermsAggregation,
@@ -280,6 +285,19 @@ impl SegmentTermCollector {
let term_buckets =
TermBuckets::from_req_and_validate(sub_aggregations, max_term_id as usize)?;
if let Some(custom_order) = req.order.as_ref() {
if let OrderTarget::SubAggregation(sub_agg_name) = &custom_order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
sub_aggregations.metrics.get(agg_name).ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {} in metric sub_aggregations",
agg_name
))
})?;
}
}
let has_sub_aggregations = !sub_aggregations.is_empty();
let blueprint = if has_sub_aggregations {
let sub_aggregation =
@@ -301,10 +319,37 @@ impl SegmentTermCollector {
self,
agg_with_accessor: &BucketAggregationWithAccessor,
) -> crate::Result<IntermediateBucketResult> {
let mut entries: Vec<_> = self.term_buckets.entries.into_iter().collect();
let mut entries: Vec<(u32, TermBucketEntry)> =
self.term_buckets.entries.into_iter().collect();
let (term_doc_count_before_cutoff, sum_other_doc_count) =
cut_off_buckets(&mut entries, self.req.segment_size as usize);
let order_by_key = self.req.order.target == OrderTarget::Key;
let order_by_sub_aggregation =
matches!(self.req.order.target, OrderTarget::SubAggregation(_));
match self.req.order.target {
OrderTarget::Key => {
// defer ordering and cut-off until after the terms have been loaded from the dictionary
}
OrderTarget::SubAggregation(_name) => {
// don't sort or cut off, since it's hard to make assumptions about the quality of
// the results when cutting off, due to the unknown nature of the sub_aggregation
// (it would be possible to check).
}
OrderTarget::Count => {
if self.req.order.order == Order::Desc {
entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.doc_count()));
} else {
entries.sort_unstable_by_key(|bucket| bucket.doc_count());
}
}
}
let (term_doc_count_before_cutoff, mut sum_other_doc_count) =
if order_by_key || order_by_sub_aggregation {
(0, 0)
} else {
cut_off_buckets(&mut entries, self.req.segment_size as usize)
};
let inverted_index = agg_with_accessor
.inverted_index
@@ -335,6 +380,20 @@ impl SegmentTermCollector {
}
}
if order_by_key {
let mut dict_entries = dict.into_iter().collect_vec();
if self.req.order.order == Order::Desc {
dict_entries.sort_unstable_by(|(key1, _), (key2, _)| key1.cmp(key2));
} else {
dict_entries.sort_unstable_by(|(key1, _), (key2, _)| key2.cmp(key1));
}
let (_, sum_other_docs) =
cut_off_buckets(&mut dict_entries, self.req.segment_size as usize);
sum_other_doc_count += sum_other_docs;
dict = dict_entries.into_iter().collect();
}
Ok(IntermediateBucketResult::Terms(
IntermediateTermBucketResult {
entries: dict,
@@ -416,13 +475,16 @@ impl GetDocCount for (u32, TermBucketEntry) {
self.1.doc_count
}
}
impl GetDocCount for (String, IntermediateTermBucketEntry) {
fn doc_count(&self) -> u64 {
self.1.doc_count
}
}
pub(crate) fn cut_off_buckets<T: GetDocCount + Debug>(
entries: &mut Vec<T>,
num_elem: usize,
) -> (u64, u64) {
entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.doc_count()));
let term_doc_count_before_cutoff = entries
.get(num_elem)
.map(|entry| entry.doc_count())
@@ -442,10 +504,12 @@ mod tests {
use super::*;
use crate::aggregation::agg_req::{
get_term_dict_field_names, Aggregation, Aggregations, BucketAggregation,
BucketAggregationType,
BucketAggregationType, MetricAggregation,
};
use crate::aggregation::metric::{AverageAggregation, StatsAggregation};
use crate::aggregation::tests::{
exec_request, exec_request_with_query, get_test_index_from_terms,
get_test_index_from_values_and_terms,
};
#[test]
@@ -487,8 +551,8 @@ mod tests {
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2);
assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 1);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "termc");
assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 1);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
let agg_req: Aggregations = vec![(
@@ -550,6 +614,447 @@ mod tests {
Ok(())
}
#[test]
fn terms_aggregation_test_order_count_single_segment() -> crate::Result<()> {
terms_aggregation_test_order_count_merge_segment(true)
}
#[test]
fn terms_aggregation_test_count_order() -> crate::Result<()> {
terms_aggregation_test_order_count_merge_segment(false)
}
fn terms_aggregation_test_order_count_merge_segment(merge_segments: bool) -> crate::Result<()> {
let segment_and_terms = vec![
vec![(5.0, "terma".to_string())],
vec![(4.0, "termb".to_string())],
vec![(1.0, "termc".to_string())],
vec![(1.0, "termc".to_string())],
vec![(1.0, "termc".to_string())],
vec![(5.0, "terma".to_string())],
vec![(5.0, "terma".to_string())],
vec![(5.0, "terma".to_string())],
vec![(8.0, "termb".to_string())],
vec![(5.0, "terma".to_string())],
];
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
let sub_agg: Aggregations = vec![
(
"avg_score".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score".to_string()),
)),
),
(
"stats_score".to_string(),
Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
"score".to_string(),
))),
),
]
.into_iter()
.collect();
// count asc
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Count,
}),
..Default::default()
}),
sub_aggregation: sub_agg,
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 2);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "termc");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 3);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 5);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
Ok(())
}
#[test]
fn terms_aggregation_test_order_sub_agg_single_segment() -> crate::Result<()> {
terms_aggregation_test_order_sub_agg_merge_segment(true)
}
#[test]
fn terms_aggregation_test_sub_agg_order() -> crate::Result<()> {
terms_aggregation_test_order_sub_agg_merge_segment(false)
}
fn terms_aggregation_test_order_sub_agg_merge_segment(
merge_segments: bool,
) -> crate::Result<()> {
let segment_and_terms = vec![
vec![(5.0, "terma".to_string())],
vec![(4.0, "termb".to_string())],
vec![(1.0, "termc".to_string())],
vec![(1.0, "termc".to_string())],
vec![(1.0, "termc".to_string())],
vec![(5.0, "terma".to_string())],
vec![(5.0, "terma".to_string())],
vec![(5.0, "terma".to_string())],
vec![(8.0, "termb".to_string())],
vec![(5.0, "terma".to_string())],
];
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
let sub_agg: Aggregations = vec![
(
"avg_score".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score".to_string()),
)),
),
(
"stats_score".to_string(),
Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
"score".to_string(),
))),
),
]
.into_iter()
.collect();
// sub agg desc
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::SubAggregation("avg_score".to_string()),
}),
..Default::default()
}),
sub_aggregation: sub_agg.clone(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 2);
assert_eq!(res["my_texts"]["buckets"][0]["avg_score"]["value"], 6.0);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 5);
assert_eq!(res["my_texts"]["buckets"][1]["avg_score"]["value"], 5.0);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "termc");
assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 3);
assert_eq!(res["my_texts"]["buckets"][2]["avg_score"]["value"], 1.0);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
// sub agg asc
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::SubAggregation("avg_score".to_string()),
}),
..Default::default()
}),
sub_aggregation: sub_agg.clone(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "termc");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 3);
assert_eq!(res["my_texts"]["buckets"][0]["avg_score"]["value"], 1.0);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 5);
assert_eq!(res["my_texts"]["buckets"][1]["avg_score"]["value"], 5.0);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 2);
assert_eq!(res["my_texts"]["buckets"][2]["avg_score"]["value"], 6.0);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
// sub agg multi value asc
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::SubAggregation("stats_score.avg".to_string()),
}),
..Default::default()
}),
sub_aggregation: sub_agg.clone(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "termc");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 3);
assert_eq!(res["my_texts"]["buckets"][0]["avg_score"]["value"], 1.0);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 5);
assert_eq!(res["my_texts"]["buckets"][1]["avg_score"]["value"], 5.0);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 2);
assert_eq!(res["my_texts"]["buckets"][2]["avg_score"]["value"], 6.0);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
// sub agg invalid request
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::SubAggregation("doesnotexist".to_string()),
}),
..Default::default()
}),
sub_aggregation: sub_agg,
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index);
assert!(res.is_err());
Ok(())
}
#[test]
fn terms_aggregation_test_order_key_single_segment() -> crate::Result<()> {
terms_aggregation_test_order_key_merge_segment(true)
}
#[test]
fn terms_aggregation_test_key_order() -> crate::Result<()> {
terms_aggregation_test_order_key_merge_segment(false)
}
fn terms_aggregation_test_order_key_merge_segment(merge_segments: bool) -> crate::Result<()> {
let segment_and_terms = vec![
vec![(5.0, "terma".to_string())],
vec![(4.0, "termb".to_string())],
vec![(1.0, "termc".to_string())],
vec![(1.0, "termc".to_string())],
vec![(1.0, "termc".to_string())],
vec![(5.0, "terma".to_string())],
vec![(5.0, "terma".to_string())],
vec![(5.0, "terma".to_string())],
vec![(8.0, "termb".to_string())],
vec![(5.0, "terma".to_string())],
];
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
// key desc
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::Key,
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "termc");
assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 3);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
// key desc and size cut_off
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::Key,
}),
size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2);
assert_eq!(
res["my_texts"]["buckets"][2]["doc_count"],
serde_json::Value::Null
);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 3);
// key desc and segment_size cut_off
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::Key,
}),
size: Some(2),
segment_size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2);
assert_eq!(
res["my_texts"]["buckets"][2]["doc_count"],
serde_json::Value::Null
);
// key asc
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Key,
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "termc");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 3);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 5);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
// key asc, size cut_off
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Key,
}),
size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "termc");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 3);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2);
assert_eq!(
res["my_texts"]["buckets"][2]["doc_count"],
serde_json::Value::Null
);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 5);
// key asc, segment_size cut_off
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Key,
}),
size: Some(2),
segment_size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "termc");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 3);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2);
assert_eq!(
res["my_texts"]["buckets"][2]["doc_count"],
serde_json::Value::Null
);
Ok(())
}
#[test]
fn terms_aggregation_min_doc_count_special_case() -> crate::Result<()> {
let terms_per_segment = vec![
@@ -627,6 +1132,32 @@ mod tests {
assert_eq!(res["my_texts"]["sum_other_doc_count"], 4);
assert_eq!(res["my_texts"]["doc_count_error_upper_bound"], 2);
// disable doc_count_error_upper_bound
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
segment_size: Some(2),
show_term_doc_count_error: Some(false),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["sum_other_doc_count"], 4);
assert_eq!(
res["my_texts"]["doc_count_error_upper_bound"],
serde_json::Value::Null
);
Ok(())
}
}

View File

@@ -86,8 +86,8 @@ impl Collector for AggregationCollector {
&self,
segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>,
) -> crate::Result<Self::Fruit> {
merge_fruits(segment_fruits)
.map(|res| AggregationResults::from_intermediate_and_req(res, self.agg.clone()))
let res = merge_fruits(segment_fruits)?;
AggregationResults::from_intermediate_and_req(res, self.agg.clone())
}
}

View File

@@ -10,7 +10,10 @@ use serde::{Deserialize, Serialize};
use super::agg_req::{AggregationsInternal, BucketAggregationType, MetricAggregation};
use super::agg_result::BucketResult;
use super::bucket::{cut_off_buckets, SegmentHistogramBucketEntry, TermsAggregation};
use super::bucket::{
cut_off_buckets, get_agg_name_and_property, GetDocCount, Order, OrderTarget,
SegmentHistogramBucketEntry, TermsAggregation,
};
use super::metric::{IntermediateAverage, IntermediateStats};
use super::segment_agg_result::SegmentMetricResultCollector;
use super::{Key, SerializedKey, VecWithNames};
@@ -154,7 +157,7 @@ impl IntermediateMetricResult {
pub enum IntermediateBucketResult {
/// This is the range entry for a bucket, which contains a key, count, from, to, and optionally
/// sub_aggregations.
Range(FnvHashMap<SerializedKey, IntermediateRangeBucketEntry>),
Range(IntermediateRangeBucketResult),
/// This is the histogram entry for a bucket, which contains a key, count, and optionally
/// sub_aggregations.
Histogram {
@@ -178,34 +181,34 @@ impl IntermediateBucketResult {
fn merge_fruits(&mut self, other: IntermediateBucketResult) {
match (self, other) {
(
IntermediateBucketResult::Terms(entries_left),
IntermediateBucketResult::Terms(entries_right),
IntermediateBucketResult::Terms(term_res_left),
IntermediateBucketResult::Terms(term_res_right),
) => {
merge_maps(&mut entries_left.entries, entries_right.entries);
entries_left.sum_other_doc_count += entries_right.sum_other_doc_count;
entries_left.doc_count_error_upper_bound +=
entries_right.doc_count_error_upper_bound;
merge_maps(&mut term_res_left.entries, term_res_right.entries);
term_res_left.sum_other_doc_count += term_res_right.sum_other_doc_count;
term_res_left.doc_count_error_upper_bound +=
term_res_right.doc_count_error_upper_bound;
}
(
IntermediateBucketResult::Range(entries_left),
IntermediateBucketResult::Range(entries_right),
IntermediateBucketResult::Range(range_res_left),
IntermediateBucketResult::Range(range_res_right),
) => {
merge_maps(entries_left, entries_right);
merge_maps(&mut range_res_left.buckets, range_res_right.buckets);
}
(
IntermediateBucketResult::Histogram {
buckets: entries_left,
buckets: buckets_left,
..
},
IntermediateBucketResult::Histogram {
buckets: entries_right,
buckets: buckets_right,
..
},
) => {
let mut buckets = entries_left
let mut buckets = buckets_left
.drain(..)
.merge_join_by(entries_right.into_iter(), |left, right| {
.merge_join_by(buckets_right.into_iter(), |left, right| {
left.key.partial_cmp(&right.key).unwrap_or(Ordering::Equal)
})
.map(|either| match either {
@@ -218,7 +221,7 @@ impl IntermediateBucketResult {
})
.collect();
std::mem::swap(entries_left, &mut buckets);
std::mem::swap(buckets_left, &mut buckets);
}
(IntermediateBucketResult::Range(_), _) => {
panic!("tried to merge different bucket result types")
@@ -233,6 +236,12 @@ impl IntermediateBucketResult {
}
}
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
/// Intermediate result of a range aggregation.
pub struct IntermediateRangeBucketResult {
pub(crate) buckets: FnvHashMap<SerializedKey, IntermediateRangeBucketEntry>,
}
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
/// Term aggregation including error counts
pub struct IntermediateTermBucketResult {
@@ -246,22 +255,75 @@ impl IntermediateTermBucketResult {
self,
req: &TermsAggregation,
sub_aggregation_req: &AggregationsInternal,
) -> BucketResult {
) -> crate::Result<BucketResult> {
let req = TermsAggregationInternal::from_req(req);
let mut buckets: Vec<BucketEntry> = self
.entries
.into_iter()
.filter(|bucket| bucket.1.doc_count >= req.min_doc_count)
.map(|(key, entry)| BucketEntry {
key: Key::Str(key),
doc_count: entry.doc_count,
sub_aggregation: AggregationResults::from_intermediate_and_req_internal(
entry.sub_aggregation,
sub_aggregation_req,
),
.map(|(key, entry)| {
Ok(BucketEntry {
key: Key::Str(key),
doc_count: entry.doc_count,
sub_aggregation: AggregationResults::from_intermediate_and_req_internal(
entry.sub_aggregation,
sub_aggregation_req,
)?,
})
})
.collect();
buckets.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.doc_count));
.collect::<crate::Result<_>>()?;
let order = req.order.order;
match req.order.target {
OrderTarget::Key => {
buckets.sort_by(|bucket1, bucket2| {
if req.order.order == Order::Desc {
bucket1
.key
.partial_cmp(&bucket2.key)
.expect("expected type string, which is always sortable")
} else {
bucket2
.key
.partial_cmp(&bucket1.key)
.expect("expected type string, which is always sortable")
}
});
}
OrderTarget::Count => {
if req.order.order == Order::Desc {
buckets.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.doc_count()));
} else {
buckets.sort_unstable_by_key(|bucket| bucket.doc_count());
}
}
OrderTarget::SubAggregation(name) => {
let (agg_name, agg_property) = get_agg_name_and_property(&name);
let mut buckets_with_val = buckets
.into_iter()
.map(|bucket| {
let val = bucket
.sub_aggregation
.get_value_from_aggregation(agg_name, agg_property)?
.unwrap_or(f64::NAN);
Ok((bucket, val))
})
.collect::<crate::Result<Vec<_>>>()?;
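// Values mapped to NaN above (missing sub-aggregation results) compare as
// Equal in partial_cmp below, so the (stable) sort keeps their relative order.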
buckets_with_val.sort_by(|(_, val1), (_, val2)| {
// TODO: use total_cmp once it is available in stable Rust
match &order {
Order::Desc => val2.partial_cmp(val1).unwrap_or(std::cmp::Ordering::Equal),
Order::Asc => val1.partial_cmp(val2).unwrap_or(std::cmp::Ordering::Equal),
}
});
buckets = buckets_with_val
.into_iter()
.map(|(bucket, _val)| bucket)
.collect_vec();
}
}
// We ignore _term_doc_count_before_cutoff here, because it increases the upper-bound
// error only for terms that didn't make it into the top N.
//
@@ -276,11 +338,11 @@ impl IntermediateTermBucketResult {
None
};
BucketResult::Terms {
Ok(BucketResult::Terms {
buckets,
sum_other_doc_count: self.sum_other_doc_count + sum_other_doc_count,
doc_count_error_upper_bound,
}
})
}
}
@@ -399,7 +461,7 @@ mod tests {
}
map.insert(
"my_agg_level2".to_string(),
IntermediateBucketResult::Range(buckets),
IntermediateBucketResult::Range(IntermediateRangeBucketResult { buckets }),
);
IntermediateAggregationResults {
buckets: Some(VecWithNames::from_entries(map.into_iter().collect())),
@@ -429,7 +491,7 @@ mod tests {
}
map.insert(
"my_agg_level1".to_string(),
IntermediateBucketResult::Range(buckets),
IntermediateBucketResult::Range(IntermediateRangeBucketResult { buckets }),
);
IntermediateAggregationResults {
buckets: Some(VecWithNames::from_entries(map.into_iter().collect())),

View File

@@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
use crate::aggregation::f64_from_fastfield_u64;
use crate::fastfield::{DynamicFastFieldReader, FastFieldReader};
use crate::schema::Type;
use crate::DocId;
use crate::{DocId, TantivyError};
/// A multi-value metric aggregation that computes stats of numeric values that are
/// extracted from the aggregated documents.
@@ -53,6 +53,23 @@ pub struct Stats {
pub avg: Option<f64>,
}
impl Stats {
pub(crate) fn get_value(&self, agg_property: &str) -> crate::Result<Option<f64>> {
match agg_property {
"count" => Ok(Some(self.count as f64)),
"sum" => Ok(Some(self.sum)),
"standard_deviation" => Ok(self.standard_deviation),
"min" => Ok(self.min),
"max" => Ok(self.max),
"avg" => Ok(self.avg),
_ => Err(TantivyError::InvalidArgument(format!(
"unknown property {} on stats metric aggregation",
agg_property
))),
}
}
}
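For illustration, a minimal sketch of the property lookup, assuming the remaining `Stats` fields are the plain `u64`/`f64`/`Option<f64>` values implied by `get_value` (the numbers are made up):

let stats = Stats {
    count: 3,
    sum: 12.0,
    standard_deviation: Some(2.16),
    min: Some(1.0),
    max: Some(6.0),
    avg: Some(4.0),
};
assert_eq!(stats.get_value("avg").unwrap(), Some(4.0));
assert_eq!(stats.get_value("count").unwrap(), Some(3.0));
assert!(stats.get_value("median").is_err()); // unknown property -> InvalidArgument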
/// IntermediateStats contains the mergeable version for stats.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateStats {

View File

@@ -247,6 +247,11 @@ impl<T: Clone> VecWithNames<T> {
fn is_empty(&self) -> bool {
self.keys.is_empty()
}
fn get(&self, name: &str) -> Option<&T> {
self.keys()
.position(|key| key == name)
.map(|pos| &self.values[pos])
}
}
/// The serialized key is used in a HashMap.
@@ -540,6 +545,7 @@ mod tests {
searcher.search(&AllQuery, &collector).unwrap(),
agg_req,
)
.unwrap()
} else {
let collector = AggregationCollector::from_aggs(agg_req);
@@ -975,7 +981,7 @@ mod tests {
// Test de/serialization roundtrip on intermediate_agg_result
let res: IntermediateAggregationResults =
serde_json::from_str(&serde_json::to_string(&res).unwrap()).unwrap();
AggregationResults::from_intermediate_and_req(res, agg_req.clone())
AggregationResults::from_intermediate_and_req(res, agg_req.clone()).unwrap()
} else {
let collector = AggregationCollector::from_aggs(agg_req.clone());

View File

@@ -29,6 +29,17 @@ pub(crate) struct SegmentAggregationResultsCollector {
num_staged_docs: usize,
}
impl Default for SegmentAggregationResultsCollector {
fn default() -> Self {
Self {
metrics: Default::default(),
buckets: Default::default(),
staged_docs: [0; DOC_BLOCK_SIZE],
num_staged_docs: Default::default(),
}
}
}
impl Debug for SegmentAggregationResultsCollector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentAggregationResultsCollector")
@@ -216,7 +227,7 @@ impl SegmentBucketResultCollector {
req.field_type,
req.accessor
.as_multi()
.expect("unexpected fast field cardinatility"),
.expect("unexpected fast field cardinality"),
)?,
))),
BucketAggregationType::Range(range_req) => {
@@ -233,7 +244,7 @@ impl SegmentBucketResultCollector {
req.field_type,
req.accessor
.as_single()
.expect("unexpected fast field cardinatility"),
.expect("unexpected fast field cardinality"),
)?,
))),
}