mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-01-06 17:22:54 +00:00
agg: support to deserialize f64 from string (#2311)
* agg: support to deserialize f64 from string * remove visit_string * disallow NaN
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{HistogramAggregation, HistogramBounds};
|
||||
use crate::aggregation::AggregationError;
|
||||
use crate::aggregation::*;
|
||||
|
||||
/// DateHistogramAggregation is similar to `HistogramAggregation`, but it can only be used with date
|
||||
/// type.
|
||||
|
||||
@@ -20,7 +20,7 @@ use crate::aggregation::intermediate_agg_result::{
|
||||
use crate::aggregation::segment_agg_result::{
|
||||
build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector,
|
||||
};
|
||||
use crate::aggregation::{f64_from_fastfield_u64, format_date};
|
||||
use crate::aggregation::*;
|
||||
use crate::TantivyError;
|
||||
|
||||
/// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`.
|
||||
@@ -73,6 +73,7 @@ pub struct HistogramAggregation {
|
||||
pub field: String,
|
||||
/// The interval to chunk your data range. Each bucket spans a value range of [0..interval).
|
||||
/// Must be a positive value.
|
||||
#[serde(deserialize_with = "deserialize_f64")]
|
||||
pub interval: f64,
|
||||
/// Intervals implicitly defines an absolute grid of buckets `[interval * k, interval * (k +
|
||||
/// 1))`.
|
||||
@@ -85,6 +86,7 @@ pub struct HistogramAggregation {
|
||||
/// fall into the buckets with the key 0 and 10.
|
||||
/// With offset 5 and interval 10, they would both fall into the bucket with they key 5 and the
|
||||
/// range [5..15)
|
||||
#[serde(default, deserialize_with = "deserialize_option_f64")]
|
||||
pub offset: Option<f64>,
|
||||
/// The minimum number of documents in a bucket to be returned. Defaults to 0.
|
||||
pub min_doc_count: Option<u64>,
|
||||
|
||||
@@ -14,9 +14,7 @@ use crate::aggregation::intermediate_agg_result::{
|
||||
use crate::aggregation::segment_agg_result::{
|
||||
build_segment_agg_collector, SegmentAggregationCollector,
|
||||
};
|
||||
use crate::aggregation::{
|
||||
f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey,
|
||||
};
|
||||
use crate::aggregation::*;
|
||||
use crate::TantivyError;
|
||||
|
||||
/// Provide user-defined buckets to aggregate on.
|
||||
@@ -72,11 +70,19 @@ pub struct RangeAggregationRange {
|
||||
pub key: Option<String>,
|
||||
/// The from range value, which is inclusive in the range.
|
||||
/// `None` equals to an open ended interval.
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
default,
|
||||
deserialize_with = "deserialize_option_f64"
|
||||
)]
|
||||
pub from: Option<f64>,
|
||||
/// The to range value, which is not inclusive in the range.
|
||||
/// `None` equals to an open ended interval.
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
default,
|
||||
deserialize_with = "deserialize_option_f64"
|
||||
)]
|
||||
pub to: Option<f64>,
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,8 @@ use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{IntermediateStats, SegmentStatsCollector};
|
||||
use super::*;
|
||||
use crate::aggregation::*;
|
||||
|
||||
/// A single-value metric aggregation that computes the average of numeric values that are
|
||||
/// extracted from the aggregated documents.
|
||||
@@ -24,7 +25,7 @@ pub struct AverageAggregation {
|
||||
/// By default they will be ignored but it is also possible to treat them as if they had a
|
||||
/// value. Examples in JSON format:
|
||||
/// { "field": "my_numbers", "missing": "10.0" }
|
||||
#[serde(default)]
|
||||
#[serde(default, deserialize_with = "deserialize_option_f64")]
|
||||
pub missing: Option<f64>,
|
||||
}
|
||||
|
||||
@@ -65,3 +66,71 @@ impl IntermediateAverage {
|
||||
self.stats.finalize().avg
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn deserialization_with_missing_test1() {
|
||||
let json = r#"{
|
||||
"field": "score",
|
||||
"missing": "10.0"
|
||||
}"#;
|
||||
let avg: AverageAggregation = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(avg.field, "score");
|
||||
assert_eq!(avg.missing, Some(10.0));
|
||||
// no dot
|
||||
let json = r#"{
|
||||
"field": "score",
|
||||
"missing": "10"
|
||||
}"#;
|
||||
let avg: AverageAggregation = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(avg.field, "score");
|
||||
assert_eq!(avg.missing, Some(10.0));
|
||||
|
||||
// from value
|
||||
let avg: AverageAggregation = serde_json::from_value(json!({
|
||||
"field": "score_f64",
|
||||
"missing": 10u64,
|
||||
}))
|
||||
.unwrap();
|
||||
assert_eq!(avg.missing, Some(10.0));
|
||||
// from value
|
||||
let avg: AverageAggregation = serde_json::from_value(json!({
|
||||
"field": "score_f64",
|
||||
"missing": 10u32,
|
||||
}))
|
||||
.unwrap();
|
||||
assert_eq!(avg.missing, Some(10.0));
|
||||
let avg: AverageAggregation = serde_json::from_value(json!({
|
||||
"field": "score_f64",
|
||||
"missing": 10i8,
|
||||
}))
|
||||
.unwrap();
|
||||
assert_eq!(avg.missing, Some(10.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialization_with_missing_test_fail() {
|
||||
let json = r#"{
|
||||
"field": "score",
|
||||
"missing": "a"
|
||||
}"#;
|
||||
let avg: Result<AverageAggregation, _> = serde_json::from_str(json);
|
||||
assert!(avg.is_err());
|
||||
assert!(avg
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("Failed to parse f64 from string: \"a\""));
|
||||
|
||||
// Disallow NaN
|
||||
let json = r#"{
|
||||
"field": "score",
|
||||
"missing": "NaN"
|
||||
}"#;
|
||||
let avg: Result<AverageAggregation, _> = serde_json::from_str(json);
|
||||
assert!(avg.is_err());
|
||||
assert!(avg.unwrap_err().to_string().contains("NaN"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,8 @@ use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{IntermediateStats, SegmentStatsCollector};
|
||||
use super::*;
|
||||
use crate::aggregation::*;
|
||||
|
||||
/// A single-value metric aggregation that counts the number of values that are
|
||||
/// extracted from the aggregated documents.
|
||||
@@ -24,7 +25,7 @@ pub struct CountAggregation {
|
||||
/// By default they will be ignored but it is also possible to treat them as if they had a
|
||||
/// value. Examples in JSON format:
|
||||
/// { "field": "my_numbers", "missing": "10.0" }
|
||||
#[serde(default)]
|
||||
#[serde(default, deserialize_with = "deserialize_option_f64")]
|
||||
pub missing: Option<f64>,
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,8 @@ use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{IntermediateStats, SegmentStatsCollector};
|
||||
use super::*;
|
||||
use crate::aggregation::*;
|
||||
|
||||
/// A single-value metric aggregation that computes the maximum of numeric values that are
|
||||
/// extracted from the aggregated documents.
|
||||
@@ -24,7 +25,7 @@ pub struct MaxAggregation {
|
||||
/// By default they will be ignored but it is also possible to treat them as if they had a
|
||||
/// value. Examples in JSON format:
|
||||
/// { "field": "my_numbers", "missing": "10.0" }
|
||||
#[serde(default)]
|
||||
#[serde(default, deserialize_with = "deserialize_option_f64")]
|
||||
pub missing: Option<f64>,
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,8 @@ use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{IntermediateStats, SegmentStatsCollector};
|
||||
use super::*;
|
||||
use crate::aggregation::*;
|
||||
|
||||
/// A single-value metric aggregation that computes the minimum of numeric values that are
|
||||
/// extracted from the aggregated documents.
|
||||
@@ -24,7 +25,7 @@ pub struct MinAggregation {
|
||||
/// By default they will be ignored but it is also possible to treat them as if they had a
|
||||
/// value. Examples in JSON format:
|
||||
/// { "field": "my_numbers", "missing": "10.0" }
|
||||
#[serde(default)]
|
||||
#[serde(default, deserialize_with = "deserialize_option_f64")]
|
||||
pub missing: Option<f64>,
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ mod percentiles;
|
||||
mod stats;
|
||||
mod sum;
|
||||
mod top_hits;
|
||||
|
||||
pub use average::*;
|
||||
pub use count::*;
|
||||
pub use max::*;
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, AggregationError};
|
||||
use crate::aggregation::*;
|
||||
use crate::{DocId, TantivyError};
|
||||
|
||||
/// # Percentiles
|
||||
@@ -84,7 +84,11 @@ pub struct PercentilesAggregationReq {
|
||||
/// By default they will be ignored but it is also possible to treat them as if they had a
|
||||
/// value. Examples in JSON format:
|
||||
/// { "field": "my_numbers", "missing": "10.0" }
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
default,
|
||||
deserialize_with = "deserialize_option_f64"
|
||||
)]
|
||||
pub missing: Option<f64>,
|
||||
}
|
||||
fn default_percentiles() -> &'static [f64] {
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64};
|
||||
use crate::aggregation::*;
|
||||
use crate::{DocId, TantivyError};
|
||||
|
||||
/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
|
||||
@@ -33,7 +33,7 @@ pub struct StatsAggregation {
|
||||
/// By default they will be ignored but it is also possible to treat them as if they had a
|
||||
/// value. Examples in JSON format:
|
||||
/// { "field": "my_numbers", "missing": "10.0" }
|
||||
#[serde(default)]
|
||||
#[serde(default, deserialize_with = "deserialize_option_f64")]
|
||||
pub missing: Option<f64>,
|
||||
}
|
||||
|
||||
@@ -580,6 +580,30 @@ mod tests {
|
||||
})
|
||||
);
|
||||
|
||||
// From string
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_stats": {
|
||||
"stats": {
|
||||
"field": "json.partially_empty",
|
||||
"missing": "0.0"
|
||||
},
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = exec_request_with_query(agg_req, &index, None)?;
|
||||
|
||||
assert_eq!(
|
||||
res["my_stats"],
|
||||
json!({
|
||||
"avg": 2.5,
|
||||
"count": 4,
|
||||
"max": 10.0,
|
||||
"min": 0.0,
|
||||
"sum": 10.0
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,8 @@ use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{IntermediateStats, SegmentStatsCollector};
|
||||
use super::*;
|
||||
use crate::aggregation::*;
|
||||
|
||||
/// A single-value metric aggregation that sums up numeric values that are
|
||||
/// extracted from the aggregated documents.
|
||||
@@ -24,7 +25,7 @@ pub struct SumAggregation {
|
||||
/// By default they will be ignored but it is also possible to treat them as if they had a
|
||||
/// value. Examples in JSON format:
|
||||
/// { "field": "my_numbers", "missing": "10.0" }
|
||||
#[serde(default)]
|
||||
#[serde(default, deserialize_with = "deserialize_option_f64")]
|
||||
pub missing: Option<f64>,
|
||||
}
|
||||
|
||||
|
||||
@@ -145,6 +145,8 @@ mod agg_tests;
|
||||
|
||||
mod agg_bench;
|
||||
|
||||
use core::fmt;
|
||||
|
||||
pub use agg_limits::AggregationLimits;
|
||||
pub use collector::{
|
||||
AggregationCollector, AggregationSegmentCollector, DistributedAggregationCollector,
|
||||
@@ -154,7 +156,106 @@ use columnar::{ColumnType, MonotonicallyMappableToU64};
|
||||
pub(crate) use date::format_date;
|
||||
pub use error::AggregationError;
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::de::{self, Visitor};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
|
||||
fn parse_str_into_f64<E: de::Error>(value: &str) -> Result<f64, E> {
|
||||
let parsed = value.parse::<f64>().map_err(|_err| {
|
||||
de::Error::custom(format!("Failed to parse f64 from string: {:?}", value))
|
||||
})?;
|
||||
|
||||
// Check if the parsed value is NaN or infinity
|
||||
if parsed.is_nan() || parsed.is_infinite() {
|
||||
Err(de::Error::custom(format!(
|
||||
"Value is not a valid f64 (NaN or Infinity): {:?}",
|
||||
value
|
||||
)))
|
||||
} else {
|
||||
Ok(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
/// deserialize Option<f64> from string or float
|
||||
pub(crate) fn deserialize_option_f64<'de, D>(deserializer: D) -> Result<Option<f64>, D::Error>
|
||||
where D: Deserializer<'de> {
|
||||
struct StringOrFloatVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for StringOrFloatVisitor {
|
||||
type Value = Option<f64>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a string or a float")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where E: de::Error {
|
||||
parse_str_into_f64(value).map(Some)
|
||||
}
|
||||
|
||||
fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
|
||||
where E: de::Error {
|
||||
Ok(Some(value))
|
||||
}
|
||||
|
||||
fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
|
||||
where E: de::Error {
|
||||
Ok(Some(value as f64))
|
||||
}
|
||||
|
||||
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
|
||||
where E: de::Error {
|
||||
Ok(Some(value as f64))
|
||||
}
|
||||
|
||||
fn visit_none<E>(self) -> Result<Self::Value, E>
|
||||
where E: de::Error {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn visit_unit<E>(self) -> Result<Self::Value, E>
|
||||
where E: de::Error {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(StringOrFloatVisitor)
|
||||
}
|
||||
|
||||
/// deserialize f64 from string or float
|
||||
pub(crate) fn deserialize_f64<'de, D>(deserializer: D) -> Result<f64, D::Error>
|
||||
where D: Deserializer<'de> {
|
||||
struct StringOrFloatVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for StringOrFloatVisitor {
|
||||
type Value = f64;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a string or a float")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where E: de::Error {
|
||||
parse_str_into_f64(value)
|
||||
}
|
||||
|
||||
fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
|
||||
where E: de::Error {
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
|
||||
where E: de::Error {
|
||||
Ok(value as f64)
|
||||
}
|
||||
|
||||
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
|
||||
where E: de::Error {
|
||||
Ok(value as f64)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(StringOrFloatVisitor)
|
||||
}
|
||||
|
||||
/// Represents an associative array `(key => values)` in a very efficient manner.
|
||||
#[derive(PartialEq, Serialize, Deserialize)]
|
||||
|
||||
Reference in New Issue
Block a user