diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index ca6ff48a9..8e3599108 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -44,22 +44,49 @@ use super::metric::{ /// The key is the user defined name of the aggregation. pub type Aggregations = HashMap; -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] /// Aggregation request. /// /// An aggregation is either a bucket or a metric. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(try_from = "AggregationForDeserialization")] pub struct Aggregation { /// The aggregation variant, which can be either a bucket or a metric. #[serde(flatten)] pub agg: AggregationVariants, - /// The sub_aggregations, only valid for bucket type aggregations. Each bucket will aggregate /// on the document set in the bucket. #[serde(rename = "aggs")] - #[serde(default)] #[serde(skip_serializing_if = "Aggregations::is_empty")] pub sub_aggregation: Aggregations, } +/// In order to display proper error message, we cannot rely on flattening +/// the json enum. Instead we introduce an intermediary struct to separate +/// the aggregation from the subaggregation. +#[derive(Deserialize)] +struct AggregationForDeserialization { + #[serde(flatten)] + pub aggs_remaining_json: serde_json::Value, + #[serde(rename = "aggs")] + #[serde(default)] + pub sub_aggregation: Aggregations, +} + +impl TryFrom for Aggregation { + type Error = serde_json::Error; + + fn try_from(value: AggregationForDeserialization) -> serde_json::Result { + let AggregationForDeserialization { + aggs_remaining_json, + sub_aggregation, + } = value; + let agg: AggregationVariants = serde_json::from_value(aggs_remaining_json)?; + Ok(Aggregation { + agg, + sub_aggregation, + }) + } +} + impl Aggregation { pub(crate) fn sub_aggregation(&self) -> &Aggregations { &self.sub_aggregation diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index 66aff3a4c..ce7a21212 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -558,10 +558,10 @@ fn test_aggregation_invalid_requests() -> crate::Result<()> { assert_eq!(agg_req_1.is_err(), true); // TODO: This should list valid values - assert_eq!( - agg_req_1.unwrap_err().to_string(), - "no variant of enum AggregationVariants found in flattened data" - ); + assert!(agg_req_1 + .unwrap_err() + .to_string() + .contains("unknown variant `doesnotmatchanyagg`, expected one of")); // TODO: This should return an error // let agg_res = avg_on_field("not_exist_field").unwrap_err(); diff --git a/src/query/boost_query.rs b/src/query/boost_query.rs index 1b6be6031..85397243a 100644 --- a/src/query/boost_query.rs +++ b/src/query/boost_query.rs @@ -2,7 +2,6 @@ use std::fmt; use crate::docset::BUFFER_LEN; use crate::fastfield::AliveBitSet; -use crate::query::explanation::does_not_match; use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight}; use crate::{DocId, DocSet, Score, SegmentReader, Term};