mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-01-06 17:22:54 +00:00
* add aggregation support for date type fixes #1332 * serialize key_as_string as rfc3339 in date histogram * update docs * enable date for range aggregation
372 lines
11 KiB
Rust
372 lines
11 KiB
Rust
use fastfield_codecs::Column;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::aggregation::f64_from_fastfield_u64;
|
|
use crate::schema::Type;
|
|
use crate::{DocId, TantivyError};
|
|
|
|
/// A multi-value metric aggregation that computes stats of numeric values that are
|
|
/// extracted from the aggregated documents.
|
|
/// Supported field types are `u64`, `i64`, and `f64`.
|
|
/// See [`Stats`] for returned statistics.
|
|
///
|
|
/// # JSON Format
|
|
/// ```json
|
|
/// {
|
|
/// "stats": {
|
|
/// "field": "score",
|
|
/// }
|
|
/// }
|
|
/// ```
|
|
|
|
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
|
pub struct StatsAggregation {
|
|
/// The field name to compute the stats on.
|
|
pub field: String,
|
|
}
|
|
|
|
impl StatsAggregation {
|
|
/// Create new StatsAggregation from a field.
|
|
pub fn from_field_name(field_name: String) -> Self {
|
|
StatsAggregation { field: field_name }
|
|
}
|
|
/// Return the field name.
|
|
pub fn field_name(&self) -> &str {
|
|
&self.field
|
|
}
|
|
}
|
|
|
|
/// Stats contains a collection of statistics.
///
/// This is the final (non-mergeable) form of the statistics; all optional values are
/// `None` when no value was collected.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Stats {
    /// The number of collected values.
    pub count: usize,
    /// The sum of the fast field values (`0.0` for count == 0).
    pub sum: f64,
    /// The population standard deviation of the fast field values (divides by `count`,
    /// not `count - 1`). `None` for count == 0.
    pub standard_deviation: Option<f64>,
    /// The min value of the fast field values. `None` for count == 0.
    pub min: Option<f64>,
    /// The max value of the fast field values. `None` for count == 0.
    pub max: Option<f64>,
    /// The average of the values. `None` for count == 0.
    pub avg: Option<f64>,
}
|
|
|
|
impl Stats {
|
|
pub(crate) fn get_value(&self, agg_property: &str) -> crate::Result<Option<f64>> {
|
|
match agg_property {
|
|
"count" => Ok(Some(self.count as f64)),
|
|
"sum" => Ok(Some(self.sum)),
|
|
"standard_deviation" => Ok(self.standard_deviation),
|
|
"min" => Ok(self.min),
|
|
"max" => Ok(self.max),
|
|
"avg" => Ok(self.avg),
|
|
_ => Err(TantivyError::InvalidArgument(format!(
|
|
"unknown property {} on stats metric aggregation",
|
|
agg_property
|
|
))),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// `IntermediateStats` contains the mergeable version for stats.
///
/// It keeps running aggregates that can be combined with `merge_fruits` and turned
/// into a [`Stats`] with `finalize`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateStats {
    /// Number of values collected so far.
    count: usize,
    /// Running sum of the collected values.
    sum: f64,
    /// Running sum of the squared values; used to derive the standard deviation.
    squared_sum: f64,
    /// Smallest value seen; `f64::MAX` (see `Default`) before anything is collected.
    min: f64,
    /// Largest value seen; `f64::MIN` (see `Default`) before anything is collected.
    max: f64,
}
|
|
impl Default for IntermediateStats {
|
|
fn default() -> Self {
|
|
Self {
|
|
count: 0,
|
|
sum: 0.0,
|
|
squared_sum: 0.0,
|
|
min: f64::MAX,
|
|
max: f64::MIN,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl IntermediateStats {
|
|
pub(crate) fn avg(&self) -> Option<f64> {
|
|
if self.count == 0 {
|
|
None
|
|
} else {
|
|
Some(self.sum / (self.count as f64))
|
|
}
|
|
}
|
|
|
|
fn square_mean(&self) -> f64 {
|
|
self.squared_sum / (self.count as f64)
|
|
}
|
|
|
|
pub(crate) fn standard_deviation(&self) -> Option<f64> {
|
|
self.avg()
|
|
.map(|average| (self.square_mean() - average * average).sqrt())
|
|
}
|
|
|
|
/// Merge data from other stats into this instance.
|
|
pub fn merge_fruits(&mut self, other: IntermediateStats) {
|
|
self.count += other.count;
|
|
self.sum += other.sum;
|
|
self.squared_sum += other.squared_sum;
|
|
self.min = self.min.min(other.min);
|
|
self.max = self.max.max(other.max);
|
|
}
|
|
|
|
/// compute final resultimprove_docs
|
|
pub fn finalize(&self) -> Stats {
|
|
let min = if self.count == 0 {
|
|
None
|
|
} else {
|
|
Some(self.min)
|
|
};
|
|
let max = if self.count == 0 {
|
|
None
|
|
} else {
|
|
Some(self.max)
|
|
};
|
|
Stats {
|
|
count: self.count,
|
|
sum: self.sum,
|
|
standard_deviation: self.standard_deviation(),
|
|
min,
|
|
max,
|
|
avg: self.avg(),
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
fn collect(&mut self, value: f64) {
|
|
self.count += 1;
|
|
self.sum += value;
|
|
self.squared_sum += value * value;
|
|
self.min = self.min.min(value);
|
|
self.max = self.max.max(value);
|
|
}
|
|
}
|
|
|
|
/// Collector that accumulates [`IntermediateStats`] from the fast-field values of
/// blocks of documents.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentStatsCollector {
    /// Running statistics for the values collected so far.
    pub(crate) stats: IntermediateStats,
    /// Type of the fast field; determines how raw `u64` values are converted to `f64`.
    field_type: Type,
}
|
|
|
|
impl SegmentStatsCollector {
|
|
pub fn from_req(field_type: Type) -> Self {
|
|
Self {
|
|
field_type,
|
|
stats: IntermediateStats::default(),
|
|
}
|
|
}
|
|
pub(crate) fn collect_block(&mut self, doc: &[DocId], field: &dyn Column<u64>) {
|
|
let mut iter = doc.chunks_exact(4);
|
|
for docs in iter.by_ref() {
|
|
let val1 = field.get_val(docs[0]);
|
|
let val2 = field.get_val(docs[1]);
|
|
let val3 = field.get_val(docs[2]);
|
|
let val4 = field.get_val(docs[3]);
|
|
let val1 = f64_from_fastfield_u64(val1, &self.field_type);
|
|
let val2 = f64_from_fastfield_u64(val2, &self.field_type);
|
|
let val3 = f64_from_fastfield_u64(val3, &self.field_type);
|
|
let val4 = f64_from_fastfield_u64(val4, &self.field_type);
|
|
self.stats.collect(val1);
|
|
self.stats.collect(val2);
|
|
self.stats.collect(val3);
|
|
self.stats.collect(val4);
|
|
}
|
|
for &doc in iter.remainder() {
|
|
let val = field.get_val(doc);
|
|
let val = f64_from_fastfield_u64(val, &self.field_type);
|
|
self.stats.collect(val);
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {

    use std::iter;

    use serde_json::Value;

    use crate::aggregation::agg_req::{
        Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation,
        RangeAggregation,
    };
    use crate::aggregation::agg_result::AggregationResults;
    use crate::aggregation::metric::StatsAggregation;
    use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values};
    use crate::aggregation::AggregationCollector;
    use crate::query::{AllQuery, TermQuery};
    use crate::schema::IndexRecordOption;
    use crate::Term;

    /// Builds a `stats` metric aggregation request over `field`.
    fn stats_on(field: &str) -> Aggregation {
        Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
            field.to_string(),
        )))
    }

    /// Expected JSON for a stats result that did not see any value.
    fn empty_stats_json() -> Value {
        json!({
            "avg": Value::Null,
            "count": 0,
            "max": Value::Null,
            "min": Value::Null,
            "standard_deviation": Value::Null,
            "sum": 0.0
        })
    }

    #[test]
    fn test_aggregation_stats_empty_index() -> crate::Result<()> {
        // test index without segments
        let values = vec![];
        let index = get_test_index_from_values(false, &values)?;

        let agg_req: Aggregations =
            iter::once(("stats".to_string(), stats_on("score"))).collect();
        let collector = AggregationCollector::from_aggs(agg_req, None, index.schema());

        let reader = index.reader()?;
        let searcher = reader.searcher();
        let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();

        let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
        assert_eq!(res["stats"], empty_stats_json());

        Ok(())
    }

    #[test]
    fn test_aggregation_stats() -> crate::Result<()> {
        let index = get_test_index_2_segments(false)?;

        let reader = index.reader()?;
        let text_field = reader.searcher().schema().get_field("text").unwrap();
        let term_query = TermQuery::new(
            Term::from_field_text(text_field, "cool"),
            IndexRecordOption::Basic,
        );

        // Range buckets over `score`, each with a nested stats sub-aggregation.
        let range_agg = Aggregation::Bucket(BucketAggregation {
            bucket_agg: BucketAggregationType::Range(RangeAggregation {
                field: "score".to_string(),
                ranges: vec![
                    (3f64..7f64).into(),
                    (7f64..19f64).into(),
                    (19f64..20f64).into(),
                ],
                ..Default::default()
            }),
            sub_aggregation: iter::once(("stats".to_string(), stats_on("score"))).collect(),
        });

        let agg_req: Aggregations = vec![
            ("stats_i64".to_string(), stats_on("score_i64")),
            ("stats_f64".to_string(), stats_on("score_f64")),
            ("stats".to_string(), stats_on("score")),
            ("range".to_string(), range_agg),
        ]
        .into_iter()
        .collect();

        let collector = AggregationCollector::from_aggs(agg_req, None, index.schema());

        let searcher = reader.searcher();
        let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();

        let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;

        assert_eq!(
            res["stats"],
            json!({
                "avg": 12.142857142857142,
                "count": 7,
                "max": 44.0,
                "min": 1.0,
                "standard_deviation": 13.65313748796613,
                "sum": 85.0
            })
        );

        assert_eq!(
            res["stats_i64"],
            json!({
                "avg": 12.142857142857142,
                "count": 7,
                "max": 44.0,
                "min": 1.0,
                "standard_deviation": 13.65313748796613,
                "sum": 85.0
            })
        );

        assert_eq!(
            res["stats_f64"],
            json!({
                "avg": 12.214285714285714,
                "count": 7,
                "max": 44.5,
                "min": 1.0,
                "standard_deviation": 13.819905785437443,
                "sum": 85.5
            })
        );

        assert_eq!(
            res["range"]["buckets"][2]["stats"],
            json!({
                "avg": 10.666666666666666,
                "count": 3,
                "max": 14.0,
                "min": 7.0,
                "standard_deviation": 2.867441755680877,
                "sum": 32.0
            })
        );

        // The last range bucket matches no document.
        assert_eq!(res["range"]["buckets"][3]["stats"], empty_stats_json());

        Ok(())
    }
}
|