// Mirror of https://github.com/quickwit-oss/tantivy.git (synced 2026-01-07).
//! Contains the aggregation request tree. Used to build an
//! [`AggregationCollector`](super::AggregationCollector).
//!
//! [`Aggregations`] is the top level entry point to create a request, which is a `HashMap<String,
//! Aggregation>`.
//!
//! Requests are compatible with the json format of elasticsearch.
//!
//! # Example
//!
//! ```
//! use tantivy::aggregation::agg_req::Aggregations;
//!
//! let elasticsearch_compatible_json_req = r#"
//! {
//!   "range": {
//!     "range": {
//!       "field": "score",
//!       "ranges": [
//!         { "from": 3.0, "to": 7.0 },
//!         { "from": 7.0, "to": 20.0 }
//!       ]
//!     }
//!   }
//! }"#;
//! let _agg_req: Aggregations = serde_json::from_str(elasticsearch_compatible_json_req).unwrap();
//! ```
use std::collections::HashSet;
|
|
|
|
use rustc_hash::FxHashMap;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use super::bucket::{
|
|
DateHistogramAggregationReq, FilterAggregation, HistogramAggregation, RangeAggregation,
|
|
TermsAggregation,
|
|
};
|
|
use super::metric::{
|
|
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
|
|
MaxAggregation, MinAggregation, PercentilesAggregationReq, StatsAggregation, SumAggregation,
|
|
TopHitsAggregationReq,
|
|
};
|
|
|
|
/// The top-level aggregation request structure, which contains [`Aggregation`] and their user
/// defined names. It is also used in buckets aggregations to define sub-aggregations.
///
/// The key is the user defined name of the aggregation.
// FxHashMap is used instead of std HashMap for faster hashing of short string keys.
pub type Aggregations = FxHashMap<String, Aggregation>;
|
/// Aggregation request.
///
/// An aggregation is either a bucket or a metric.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(try_from = "AggregationForDeserialization")]
pub struct Aggregation {
    /// The aggregation variant, which can be either a bucket or a metric.
    #[serde(flatten)]
    pub agg: AggregationVariants,
    /// The sub-aggregations (JSON key `aggs`), which are executed
    /// on the document set in the bucket.
    #[serde(rename = "aggs")]
    #[serde(skip_serializing_if = "Aggregations::is_empty")]
    pub sub_aggregation: Aggregations,
}
/// In order to display proper error message, we cannot rely on flattening
/// the json enum. Instead we introduce an intermediary struct to separate
/// the aggregation from the subaggregation.
#[derive(Deserialize)]
struct AggregationForDeserialization {
    /// Every JSON key except `aggs`; parsed into [`AggregationVariants`]
    /// in the `TryFrom` conversion below so parse errors are reported clearly.
    #[serde(flatten)]
    pub aggs_remaining_json: serde_json::Value,
    /// Optional sub-aggregations under the `aggs` key; empty when absent.
    #[serde(rename = "aggs")]
    #[serde(default)]
    pub sub_aggregation: Aggregations,
}
impl TryFrom<AggregationForDeserialization> for Aggregation {
|
|
type Error = serde_json::Error;
|
|
|
|
fn try_from(value: AggregationForDeserialization) -> serde_json::Result<Self> {
|
|
let AggregationForDeserialization {
|
|
aggs_remaining_json,
|
|
sub_aggregation,
|
|
} = value;
|
|
let agg: AggregationVariants = serde_json::from_value(aggs_remaining_json)?;
|
|
Ok(Aggregation {
|
|
agg,
|
|
sub_aggregation,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl Aggregation {
|
|
pub(crate) fn sub_aggregation(&self) -> &Aggregations {
|
|
&self.sub_aggregation
|
|
}
|
|
|
|
fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
|
|
fast_field_names.extend(
|
|
self.agg
|
|
.get_fast_field_names()
|
|
.iter()
|
|
.map(|s| s.to_string()),
|
|
);
|
|
fast_field_names.extend(get_fast_field_names(&self.sub_aggregation));
|
|
}
|
|
}
|
|
|
|
/// Extract all fast field names used in the tree.
|
|
pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
|
|
let mut fast_field_names = Default::default();
|
|
for el in aggs.values() {
|
|
el.get_fast_field_names(&mut fast_field_names)
|
|
}
|
|
fast_field_names
|
|
}
|
|
|
|
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// All aggregation types.
pub enum AggregationVariants {
    // Bucket aggregation types
    /// Put data into buckets of user-defined ranges.
    #[serde(rename = "range")]
    Range(RangeAggregation),
    /// Put data into a histogram.
    #[serde(rename = "histogram")]
    Histogram(HistogramAggregation),
    /// Put data into a date histogram.
    #[serde(rename = "date_histogram")]
    DateHistogram(DateHistogramAggregationReq),
    /// Put data into buckets of terms.
    #[serde(rename = "terms")]
    Terms(TermsAggregation),
    /// Filter documents into a single bucket.
    #[serde(rename = "filter")]
    Filter(FilterAggregation),

    // Metric aggregation types
    /// Computes the average of the extracted values.
    #[serde(rename = "avg")]
    Average(AverageAggregation),
    /// Counts the number of extracted values.
    #[serde(rename = "value_count")]
    Count(CountAggregation),
    /// Finds the maximum value.
    #[serde(rename = "max")]
    Max(MaxAggregation),
    /// Finds the minimum value.
    #[serde(rename = "min")]
    Min(MinAggregation),
    /// Computes a collection of statistics (`min`, `max`, `sum`, `count`, and `avg`) over the
    /// extracted values.
    #[serde(rename = "stats")]
    Stats(StatsAggregation),
    /// Computes a collection of extended statistics (`min`, `max`, `sum`, `count`, `avg`,
    /// `sum_of_squares`, `variance`, `variance_sampling`, `std_deviation`,
    /// `std_deviation_sampling`) over the extracted values.
    #[serde(rename = "extended_stats")]
    ExtendedStats(ExtendedStatsAggregation),
    /// Computes the sum of the extracted values.
    #[serde(rename = "sum")]
    Sum(SumAggregation),
    /// Computes percentiles over the extracted values.
    #[serde(rename = "percentiles")]
    Percentiles(PercentilesAggregationReq),
    /// Finds the top k values matching some order
    #[serde(rename = "top_hits")]
    TopHits(TopHitsAggregationReq),
    /// Computes an estimate of the number of unique values
    #[serde(rename = "cardinality")]
    Cardinality(CardinalityAggregationReq),
}
impl AggregationVariants {
|
|
/// Returns the name of the fields used by the aggregation.
|
|
pub fn get_fast_field_names(&self) -> Vec<&str> {
|
|
match self {
|
|
AggregationVariants::Terms(terms) => vec![terms.field.as_str()],
|
|
AggregationVariants::Range(range) => vec![range.field.as_str()],
|
|
AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()],
|
|
AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()],
|
|
AggregationVariants::Filter(filter) => filter.get_fast_field_names(),
|
|
AggregationVariants::Average(avg) => vec![avg.field_name()],
|
|
AggregationVariants::Count(count) => vec![count.field_name()],
|
|
AggregationVariants::Max(max) => vec![max.field_name()],
|
|
AggregationVariants::Min(min) => vec![min.field_name()],
|
|
AggregationVariants::Stats(stats) => vec![stats.field_name()],
|
|
AggregationVariants::ExtendedStats(extended_stats) => vec![extended_stats.field_name()],
|
|
AggregationVariants::Sum(sum) => vec![sum.field_name()],
|
|
AggregationVariants::Percentiles(per) => vec![per.field_name()],
|
|
AggregationVariants::TopHits(top_hits) => top_hits.field_names(),
|
|
AggregationVariants::Cardinality(per) => vec![per.field_name()],
|
|
}
|
|
}
|
|
|
|
pub(crate) fn as_range(&self) -> Option<&RangeAggregation> {
|
|
match &self {
|
|
AggregationVariants::Range(range) => Some(range),
|
|
_ => None,
|
|
}
|
|
}
|
|
pub(crate) fn as_histogram(&self) -> crate::Result<Option<HistogramAggregation>> {
|
|
match &self {
|
|
AggregationVariants::Histogram(histogram) => Ok(Some(histogram.clone())),
|
|
AggregationVariants::DateHistogram(histogram) => {
|
|
Ok(Some(histogram.to_histogram_req()?))
|
|
}
|
|
_ => Ok(None),
|
|
}
|
|
}
|
|
pub(crate) fn as_term(&self) -> Option<&TermsAggregation> {
|
|
match &self {
|
|
AggregationVariants::Terms(terms) => Some(terms),
|
|
_ => None,
|
|
}
|
|
}
|
|
pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> {
|
|
match &self {
|
|
AggregationVariants::Percentiles(percentile_req) => Some(percentile_req),
|
|
_ => None,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    // `json!` is used by `test_get_fast_field_names` below; it is not brought
    // in by `use super::*`, so it needs an explicit import.
    use serde_json::json;

    use super::*;

    /// Each metric aggregation type deserializes from the
    /// elasticsearch-compatible JSON request format.
    #[test]
    fn deser_json_test() {
        let agg_req_json = r#"{
            "price_avg": { "avg": { "field": "price" } },
            "price_count": { "value_count": { "field": "price" } },
            "price_max": { "max": { "field": "price" } },
            "price_min": { "min": { "field": "price" } },
            "price_stats": { "stats": { "field": "price" } },
            "price_sum": { "sum": { "field": "price" } }
        }"#;
        let _agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap();
    }

    /// Bucket aggregations with nested sub-aggregations (`aggs`) deserialize.
    #[test]
    fn deser_json_test_bucket() {
        let agg_req_json = r#"
        {
            "termagg": {
                "terms": {
                    "field": "json.mixed_type",
                    "order": { "min_price": "desc" }
                },
                "aggs": {
                    "min_price": { "min": { "field": "json.mixed_type" } }
                }
            },
            "rangeagg": {
                "range": {
                    "field": "json.mixed_type",
                    "ranges": [
                        { "to": 3.0 },
                        { "from": 19.0, "to": 20.0 },
                        { "from": 20.0 }
                    ]
                },
                "aggs": {
                    "average_in_range": { "avg": { "field": "json.mixed_type" } }
                }
            }
        } "#;

        let _agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap();
    }

    /// Metric aggregations deserialize into the expected variants with the
    /// expected field names.
    #[test]
    fn test_metric_aggregations_deser() {
        let agg_req_json = r#"{
            "price_avg": { "avg": { "field": "price" } },
            "price_count": { "value_count": { "field": "price" } },
            "price_max": { "max": { "field": "price" } },
            "price_min": { "min": { "field": "price" } },
            "price_stats": { "stats": { "field": "price" } },
            "price_sum": { "sum": { "field": "price" } }
        }"#;
        let agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap();

        assert!(
            matches!(&agg_req.get("price_avg").unwrap().agg, AggregationVariants::Average(avg) if avg.field == "price")
        );
        assert!(
            matches!(&agg_req.get("price_count").unwrap().agg, AggregationVariants::Count(count) if count.field == "price")
        );
        assert!(
            matches!(&agg_req.get("price_max").unwrap().agg, AggregationVariants::Max(max) if max.field == "price")
        );
        assert!(
            matches!(&agg_req.get("price_min").unwrap().agg, AggregationVariants::Min(min) if min.field == "price")
        );
        assert!(
            matches!(&agg_req.get("price_stats").unwrap().agg, AggregationVariants::Stats(stats) if stats.field == "price")
        );
        assert!(
            matches!(&agg_req.get("price_sum").unwrap().agg, AggregationVariants::Sum(sum) if sum.field == "price")
        );
    }

    /// Round-trip: serializing a deserialized request must reproduce the
    /// input byte for byte. The literal below matches serde_json's
    /// `to_string_pretty` output exactly (2-space indentation).
    #[test]
    fn serialize_to_json_test() {
        let elasticsearch_compatible_json_req = r#"{
  "range": {
    "range": {
      "field": "score",
      "ranges": [
        {
          "to": 3.0
        },
        {
          "from": 3.0,
          "to": 7.0
        },
        {
          "from": 7.0,
          "to": 20.0
        },
        {
          "from": 20.0
        }
      ],
      "keyed": true
    }
  }
}"#;

        let agg_req1: Aggregations =
            { serde_json::from_str(elasticsearch_compatible_json_req).unwrap() };

        let agg_req2: String = serde_json::to_string_pretty(&agg_req1).unwrap();
        assert_eq!(agg_req2, elasticsearch_compatible_json_req);
    }

    /// Fast field names are collected from the whole tree, including
    /// sub-aggregations.
    #[test]
    fn test_get_fast_field_names() {
        let range_agg: Aggregation = {
            serde_json::from_value(json!({
                "range": {
                    "field": "score",
                    "ranges": [
                        { "to": 3.0 },
                        { "from": 3.0, "to": 7.0 },
                        { "from": 7.0, "to": 20.0 },
                        { "from": 20.0 }
                    ],
                }
            }))
            .unwrap()
        };

        let agg_req1: Aggregations = {
            serde_json::from_value(json!({
                "range1": range_agg,
                "range2": {
                    "range": {
                        "field": "score2",
                        "ranges": [
                            { "to": 3.0 },
                            { "from": 3.0, "to": 7.0 },
                            { "from": 7.0, "to": 20.0 },
                            { "from": 20.0 }
                        ],
                    },
                    "aggs": {
                        "metric": {
                            "avg": {
                                "field": "field123"
                            }
                        }
                    }
                }
            }))
            .unwrap()
        };

        assert_eq!(
            get_fast_field_names(&agg_req1),
            vec![
                "score".to_string(),
                "score2".to_string(),
                "field123".to_string()
            ]
            .into_iter()
            .collect()
        )
    }
}