From 972cb6c26d002b2f83548855fe8cca0b98bf93d3 Mon Sep 17 00:00:00 2001 From: PSeitz Date: Mon, 21 Feb 2022 01:59:11 +0100 Subject: [PATCH] Aggregation (#1276) Added support for aggregation compatible with Elasticsearch's API. --- CHANGELOG.md | 2 + Cargo.toml | 10 +- examples/aggregation.rs | 130 +++ src/aggregation/README.md | 36 + src/aggregation/agg_req.rs | 169 +++ src/aggregation/agg_req_with_accessor.rs | 140 +++ src/aggregation/agg_result.rs | 142 +++ src/aggregation/bucket/mod.rs | 10 + src/aggregation/bucket/range.rs | 536 +++++++++ src/aggregation/collector.rs | 135 +++ src/aggregation/intermediate_agg_result.rs | 304 ++++++ src/aggregation/metric/average.rs | 101 ++ src/aggregation/metric/mod.rs | 22 + src/aggregation/metric/stats.rs | 273 +++++ src/aggregation/mod.rs | 1148 ++++++++++++++++++++ src/aggregation/segment_agg_result.rs | 195 ++++ src/collector/histogram_collector.rs | 2 +- src/error.rs | 3 + src/fastfield/bytes/mod.rs | 4 +- src/fastfield/reader.rs | 3 + src/lib.rs | 1 + 21 files changed, 3354 insertions(+), 12 deletions(-) create mode 100644 examples/aggregation.rs create mode 100644 src/aggregation/README.md create mode 100644 src/aggregation/agg_req.rs create mode 100644 src/aggregation/agg_req_with_accessor.rs create mode 100644 src/aggregation/agg_result.rs create mode 100644 src/aggregation/bucket/mod.rs create mode 100644 src/aggregation/bucket/range.rs create mode 100644 src/aggregation/collector.rs create mode 100644 src/aggregation/intermediate_agg_result.rs create mode 100644 src/aggregation/metric/average.rs create mode 100644 src/aggregation/metric/mod.rs create mode 100644 src/aggregation/metric/stats.rs create mode 100644 src/aggregation/mod.rs create mode 100644 src/aggregation/segment_agg_result.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d9abb667..8f59742c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ Tantivy 0.17 - Bugfix that could in theory impact durability in theory on some filesystems [#1224](https://github.com/quickwit-oss/tantivy/issues/1224) - Schema now offers not indexing fieldnorms (@lpouget) [#922](https://github.com/quickwit-oss/tantivy/issues/922) - Reduce the number of fsync calls [#1225](https://github.com/quickwit-oss/tantivy/issues/1225) +- Fix opening bytes index with dynamic codec (@PSeitz) [#1278](https://github.com/quickwit-oss/tantivy/issues/1278) +- Added an aggregation collector compatible with Elasticsearch (@PSeitz) Tantivy 0.16.2 ================================ diff --git a/Cargo.toml b/Cargo.toml index 62facaa9a..7bda25725 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ lru = "0.7.0" fastdivide = "0.4" itertools = "0.10.0" measure_time = "0.8.0" +pretty_assertions = "1.1.0" [target.'cfg(windows)'.dependencies] winapi = "0.3.9" @@ -78,11 +79,6 @@ opt-level = 3 debug = false debug-assertions = false -[profile.bench] -opt-level = 3 -debug = true -debug-assertions = false - [profile.test] debug-assertions = true overflow-checks = true @@ -116,7 +112,3 @@ required-features = ["fail/failpoints"] [[bench]] name = "analyzer" harness = false - -[[bench]] -name = "index-bench" -harness = false diff --git a/examples/aggregation.rs b/examples/aggregation.rs new file mode 100644 index 000000000..084b125a1 --- /dev/null +++ b/examples/aggregation.rs @@ -0,0 +1,130 @@ +// # Aggregation example +// +// This example shows how you can use built-in aggregations. +// We will use range buckets and compute the average in each bucket. 
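//
// The typed request built below corresponds to an Elasticsearch-style JSON request.
// A rough sketch of that JSON, for orientation only (the authoritative request is the
// code below; serde renames `sub_aggregation` to "aggs" and tags the metric as "avg"):
//
// {
//   "score_ranges": {
//     "range": {
//       "field": "highscore",
//       "ranges": [
//         { "from": -1.0, "to": 9.0 },
//         { "from": 9.0, "to": 14.0 },
//         { "from": 14.0, "to": 20.0 }
//       ]
//     },
//     "aggs": {
//       "average_price": { "avg": { "field": "price" } }
//     }
//   }
// }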
+// + +use serde_json::Value; +use tantivy::aggregation::agg_req::{ + Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation, + RangeAggregation, +}; +use tantivy::aggregation::agg_result::AggregationResults; +use tantivy::aggregation::metric::AverageAggregation; +use tantivy::aggregation::AggregationCollector; +use tantivy::query::TermQuery; +use tantivy::schema::{self, Cardinality, IndexRecordOption, Schema, TextFieldIndexing}; +use tantivy::{doc, Index, Term}; + +fn main() -> tantivy::Result<()> { + let mut schema_builder = Schema::builder(); + let text_fieldtype = schema::TextOptions::default() + .set_indexing_options( + TextFieldIndexing::default() + .set_tokenizer("default") + .set_index_option(IndexRecordOption::WithFreqs), + ) + .set_stored(); + let text_field = schema_builder.add_text_field("text", text_fieldtype); + let score_fieldtype = crate::schema::IntOptions::default().set_fast(Cardinality::SingleValue); + let highscore_field = schema_builder.add_f64_field("highscore", score_fieldtype.clone()); + let price_field = schema_builder.add_f64_field("price", score_fieldtype.clone()); + + let schema = schema_builder.build(); + + // # Indexing documents + // + // Lets index a bunch of documents for this example. + let index = Index::create_in_ram(schema); + + let mut index_writer = index.writer(50_000_000)?; + // writing the segment + index_writer.add_document(doc!( + text_field => "cool", + highscore_field => 1f64, + price_field => 0f64, + ))?; + index_writer.add_document(doc!( + text_field => "cool", + highscore_field => 3f64, + price_field => 1f64, + ))?; + index_writer.add_document(doc!( + text_field => "cool", + highscore_field => 5f64, + price_field => 1f64, + ))?; + index_writer.add_document(doc!( + text_field => "nohit", + highscore_field => 6f64, + price_field => 2f64, + ))?; + index_writer.add_document(doc!( + text_field => "cool", + highscore_field => 7f64, + price_field => 2f64, + ))?; + index_writer.commit()?; + index_writer.add_document(doc!( + text_field => "cool", + highscore_field => 11f64, + price_field => 10f64, + ))?; + index_writer.add_document(doc!( + text_field => "cool", + highscore_field => 14f64, + price_field => 15f64, + ))?; + + index_writer.add_document(doc!( + text_field => "cool", + highscore_field => 15f64, + price_field => 20f64, + ))?; + + index_writer.commit()?; + + let reader = index.reader()?; + let text_field = reader.searcher().schema().get_field("text").unwrap(); + + let term_query = TermQuery::new( + Term::from_field_text(text_field, "cool"), + IndexRecordOption::Basic, + ); + + let sub_agg_req_1: Aggregations = vec![( + "average_price".to_string(), + Aggregation::Metric(MetricAggregation::Average( + AverageAggregation::from_field_name("price".to_string()), + )), + )] + .into_iter() + .collect(); + + let agg_req_1: Aggregations = vec![( + "score_ranges".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Range(RangeAggregation { + field: "highscore".to_string(), + ranges: vec![ + (-1f64..9f64).into(), + (9f64..14f64).into(), + (14f64..20f64).into(), + ], + }), + sub_aggregation: sub_agg_req_1.clone(), + }), + )] + .into_iter() + .collect(); + + let collector = AggregationCollector::from_aggs(agg_req_1); + + let searcher = reader.searcher(); + let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap(); + + let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?; + println!("{}", serde_json::to_string_pretty(&res)?); + + Ok(()) +} 
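// The pretty-printed output has roughly the following shape: one bucket per requested
// range plus the two automatically created open-ended buckets, each carrying the
// `average_price` sub-aggregation (doc counts and averages elided here):
//
// {
//   "score_ranges": {
//     "buckets": [
//       { "key": "*--1", "doc_count": ..., "average_price": { "value": ... }, "to": -1.0 },
//       { "key": "-1-9", "doc_count": ..., "average_price": { "value": ... }, "from": -1.0, "to": 9.0 },
//       ...
//       { "key": "20-*", "doc_count": ..., "average_price": { "value": ... }, "from": 20.0 }
//     ]
//   }
// }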
diff --git a/src/aggregation/README.md b/src/aggregation/README.md new file mode 100644 index 000000000..938006962 --- /dev/null +++ b/src/aggregation/README.md @@ -0,0 +1,36 @@
# Contributing

When adding a new bucket aggregation, make sure to extend the "test_aggregation_flushing" test for at least 2 levels.

# Code Organization

Tantivy's aggregations have been designed to mimic the
[aggregations of Elasticsearch](https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations.html).

The code is organized in submodules:

#### bucket
Contains all bucket aggregations, like the range aggregation. Bucket aggregations group documents into buckets and can contain sub-aggregations.

#### metric
Contains all metric aggregations, like the average aggregation. Metric aggregations do not have sub-aggregations.

#### agg_req
agg_req contains the user's aggregation request. Deserialization from JSON is compatible with Elasticsearch aggregation requests.

#### agg_req_with_accessor
agg_req_with_accessor contains the user's aggregation request enriched with fast field accessors and other metadata, which are
used during collection.

#### segment_agg_result
segment_agg_result contains the aggregation result tree used when collecting a segment.
The tree from agg_req_with_accessor is passed along during collection.

#### intermediate_agg_result
intermediate_agg_result contains the aggregation result tree for merging with other trees.

#### agg_result
agg_result contains the final aggregation result tree.

diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs new file mode 100644 index 000000000..534c738d2 --- /dev/null +++ b/src/aggregation/agg_req.rs @@ -0,0 +1,169 @@
//! Contains the aggregation request tree. Used to build an
//! [AggregationCollector](super::AggregationCollector).
//!
//! [Aggregations] is the top-level entry point to create a request, which is a
//! `HashMap<String, Aggregation>`.
//! Requests are compatible with the JSON format of Elasticsearch aggregation requests.
//!
//! # Example
//!
//! ```
//! use tantivy::aggregation::bucket::RangeAggregation;
//! use tantivy::aggregation::agg_req::BucketAggregationType;
//! use tantivy::aggregation::agg_req::{Aggregation, Aggregations};
//! use tantivy::aggregation::agg_req::BucketAggregation;
//! let agg_req1: Aggregations = vec![
//!     (
//!         "range".to_string(),
//!         Aggregation::Bucket(BucketAggregation {
//!             bucket_agg: BucketAggregationType::Range(RangeAggregation{
//!                 field: "score".to_string(),
//!                 ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
//!             }),
//!             sub_aggregation: Default::default(),
//!         }),
//!     ),
//! ]
//! .into_iter()
//! .collect();
//!
//! let elasticsearch_compatible_json_req = r#"
//! {
//!   "range": {
//!     "range": {
//!       "field": "score",
//!       "ranges": [
//!         { "from": 3.0, "to": 7.0 },
//!         { "from": 7.0, "to": 20.0 }
//!       ]
//!     }
//!   }
//! }"#;
//! let agg_req2: Aggregations = serde_json::from_str(elasticsearch_compatible_json_req).unwrap();
//! assert_eq!(agg_req1, agg_req2);
//! ```

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

pub use super::bucket::RangeAggregation;
use super::metric::{AverageAggregation, StatsAggregation};

/// The top-level aggregation request structure, which contains [Aggregation] entries and their
/// user-defined names.
///
/// The key is the user-defined name of the aggregation.
pub type Aggregations = HashMap<String, Aggregation>;

/// Aggregation request of [BucketAggregation] or [MetricAggregation].
///
/// An aggregation is either a bucket or a metric.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Aggregation {
    /// Bucket aggregation, see [BucketAggregation] for details.
    Bucket(BucketAggregation),
    /// Metric aggregation, see [MetricAggregation] for details.
    Metric(MetricAggregation),
}

/// BucketAggregations create buckets of documents. Each bucket is associated with a rule which
/// determines whether or not a document falls into it. In other words, the buckets effectively
/// define document sets. Buckets are not necessarily disjoint, therefore a document can fall into
/// multiple buckets. In addition to the buckets themselves, bucket aggregations also compute and
/// return the number of documents in each bucket. Bucket aggregations, as opposed to metric
/// aggregations, can hold sub-aggregations. These sub-aggregations will be aggregated for the
/// buckets created by their "parent" bucket aggregation. There are different bucket aggregators,
/// each with a different "bucketing" strategy. Some define a single bucket, some define a fixed
/// number of buckets, and others dynamically create buckets during the aggregation process.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BucketAggregation {
    /// Bucket aggregation strategy to group documents.
    #[serde(flatten)]
    pub bucket_agg: BucketAggregationType,
    /// The sub_aggregations in the buckets. Each bucket will aggregate on the document set in the
    /// bucket.
    #[serde(rename = "aggs")]
    #[serde(default)]
    #[serde(skip_serializing_if = "Aggregations::is_empty")]
    pub sub_aggregation: Aggregations,
}

/// The bucket aggregation types.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum BucketAggregationType {
    /// Put data into buckets of user-defined ranges.
    #[serde(rename = "range")]
    Range(RangeAggregation),
}

/// The aggregations in this family compute metrics based on values extracted from the documents
/// that are being aggregated. Values are extracted from the fast field of the document.
///
/// Some aggregations output a single numeric metric (e.g. Average) and are called single-value
/// numeric metric aggregations; others generate multiple metrics (e.g. Stats) and are called
/// multi-value numeric metric aggregations.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum MetricAggregation {
    /// Calculates the average.
    #[serde(rename = "avg")]
    Average(AverageAggregation),
    /// Calculates stats (sum, average, min, max, standard deviation) on a field.
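    ///
    /// With the serde renames above, a metric request in the ES-compatible JSON format looks
    /// like `{ "avg": { "field": "score" } }` or `{ "stats": { "field": "score" } }` (the field
    /// name here is just an illustration).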
+ #[serde(rename = "stats")] + Stats(StatsAggregation), +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_to_json_test() { + let agg_req1: Aggregations = vec![( + "range".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Range(RangeAggregation { + field: "score".to_string(), + ranges: vec![ + (f64::MIN..3f64).into(), + (3f64..7f64).into(), + (7f64..20f64).into(), + (20f64..f64::MAX).into(), + ], + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let elasticsearch_compatible_json_req = r#"{ + "range": { + "range": { + "field": "score", + "ranges": [ + { + "to": 3.0 + }, + { + "from": 3.0, + "to": 7.0 + }, + { + "from": 7.0, + "to": 20.0 + }, + { + "from": 20.0 + } + ] + } + } +}"#; + let agg_req2: String = serde_json::to_string_pretty(&agg_req1).unwrap(); + assert_eq!(agg_req2, elasticsearch_compatible_json_req); + } +} diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs new file mode 100644 index 000000000..b530176dd --- /dev/null +++ b/src/aggregation/agg_req_with_accessor.rs @@ -0,0 +1,140 @@ +//! This will enhance the request tree with access to the fastfield and metadata. + +use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation}; +use super::bucket::RangeAggregation; +use super::metric::{AverageAggregation, StatsAggregation}; +use super::VecWithNames; +use crate::fastfield::DynamicFastFieldReader; +use crate::schema::Type; +use crate::{SegmentReader, TantivyError}; + +#[derive(Clone, Default)] +pub(crate) struct AggregationsWithAccessor { + pub metrics: VecWithNames, + pub buckets: VecWithNames, +} + +impl AggregationsWithAccessor { + fn from_data( + metrics: VecWithNames, + buckets: VecWithNames, + ) -> Self { + Self { metrics, buckets } + } + + pub fn is_empty(&self) -> bool { + self.metrics.is_empty() && self.buckets.is_empty() + } +} + +#[derive(Clone)] +pub struct BucketAggregationWithAccessor { + /// In general there can be buckets without fast field access, e.g. buckets that are created + /// based on search terms. So eventually this needs to be Option or moved. + pub(crate) accessor: DynamicFastFieldReader, + pub(crate) field_type: Type, + pub(crate) bucket_agg: BucketAggregationType, + pub(crate) sub_aggregation: AggregationsWithAccessor, +} + +impl BucketAggregationWithAccessor { + fn from_bucket( + bucket: &BucketAggregationType, + sub_aggregation: &Aggregations, + reader: &SegmentReader, + ) -> crate::Result { + let (accessor, field_type) = match &bucket { + BucketAggregationType::Range(RangeAggregation { + field: field_name, + ranges: _, + }) => get_ff_reader_and_validate(reader, field_name)?, + }; + let sub_aggregation = sub_aggregation.clone(); + Ok(BucketAggregationWithAccessor { + accessor, + field_type, + sub_aggregation: get_aggregations_with_accessor(&sub_aggregation, reader)?, + bucket_agg: bucket.clone(), + }) + } +} + +/// Contains the metric request and the fast field accessor. 
+#[derive(Clone)] +pub struct MetricAggregationWithAccessor { + pub metric: MetricAggregation, + pub field_type: Type, + pub accessor: DynamicFastFieldReader, +} + +impl MetricAggregationWithAccessor { + fn from_metric( + metric: &MetricAggregation, + reader: &SegmentReader, + ) -> crate::Result { + match &metric { + MetricAggregation::Average(AverageAggregation { field: field_name }) + | MetricAggregation::Stats(StatsAggregation { field: field_name }) => { + let (accessor, field_type) = get_ff_reader_and_validate(reader, field_name)?; + + Ok(MetricAggregationWithAccessor { + accessor, + field_type, + metric: metric.clone(), + }) + } + } + } +} + +pub(crate) fn get_aggregations_with_accessor( + aggs: &Aggregations, + reader: &SegmentReader, +) -> crate::Result { + let mut metrics = vec![]; + let mut buckets = vec![]; + for (key, agg) in aggs.iter() { + match agg { + Aggregation::Bucket(bucket) => buckets.push(( + key.to_string(), + BucketAggregationWithAccessor::from_bucket( + &bucket.bucket_agg, + &bucket.sub_aggregation, + reader, + )?, + )), + Aggregation::Metric(metric) => metrics.push(( + key.to_string(), + MetricAggregationWithAccessor::from_metric(metric, reader)?, + )), + } + } + Ok(AggregationsWithAccessor::from_data( + VecWithNames::from_entries(metrics), + VecWithNames::from_entries(buckets), + )) +} + +fn get_ff_reader_and_validate( + reader: &SegmentReader, + field_name: &str, +) -> crate::Result<(DynamicFastFieldReader, Type)> { + let field = reader + .schema() + .get_field(field_name) + .ok_or_else(|| TantivyError::FieldNotFound(field_name.to_string()))?; + let field_type = reader.schema().get_field_entry(field).field_type(); + if field_type.value_type() != Type::I64 + && field_type.value_type() != Type::U64 + && field_type.value_type() != Type::F64 + { + return Err(TantivyError::InvalidArgument(format!( + "Invalid field type in aggregation {:?}, only f64, u64, i64 is supported", + field_type.value_type() + ))); + } + let ff_fields = reader.fast_fields(); + ff_fields + .u64_lenient(field) + .map(|field| (field, field_type.value_type())) +} diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs new file mode 100644 index 000000000..584ec7f21 --- /dev/null +++ b/src/aggregation/agg_result.rs @@ -0,0 +1,142 @@ +//! Contains the final aggregation tree. +//! This tree can be converted via the `into()` method from `IntermediateAggregationResults`. +//! This conversion computes the final result. For example: The intermediate result contains +//! intermediate average results, which is the sum and the number of values. The actual average is +//! calculated on the step from intermediate to final aggregation result tree. + +use std::cmp::Ordering; +use std::collections::HashMap; + +use itertools::Itertools; +use serde::{Deserialize, Serialize}; + +use super::intermediate_agg_result::{ + IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, + IntermediateMetricResult, IntermediateRangeBucketEntry, +}; +use super::metric::{SingleMetricResult, Stats}; +use super::Key; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +/// The final aggegation result. +pub struct AggregationResults(pub HashMap); + +impl From for AggregationResults { + fn from(tree: IntermediateAggregationResults) -> Self { + Self( + tree.0 + .into_iter() + .map(|(key, agg)| (key, agg.into())) + .collect(), + ) + } +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(untagged)] +/// An aggregation is either a bucket or a metric. 
+pub enum AggregationResult { + /// Bucket result variant. + BucketResult(BucketResult), + /// Metric result variant. + MetricResult(MetricResult), +} +impl From for AggregationResult { + fn from(tree: IntermediateAggregationResult) -> Self { + match tree { + IntermediateAggregationResult::Bucket(bucket) => { + AggregationResult::BucketResult(bucket.into()) + } + IntermediateAggregationResult::Metric(metric) => { + AggregationResult::MetricResult(metric.into()) + } + } + } +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(untagged)] +/// MetricResult +pub enum MetricResult { + /// Average metric result. + Average(SingleMetricResult), + /// Stats metric result. + Stats(Stats), +} + +impl From for MetricResult { + fn from(metric: IntermediateMetricResult) -> Self { + match metric { + IntermediateMetricResult::Average(avg_data) => { + MetricResult::Average(avg_data.finalize().into()) + } + IntermediateMetricResult::Stats(intermediate_stats) => { + MetricResult::Stats(intermediate_stats.finalize()) + } + } + } +} + +/// BucketEntry holds bucket aggregation result types. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum BucketResult { + /// This is the default entry for a bucket, which contains a key, count, and optionally + /// sub_aggregations. + Range { + /// The range buckets sorted by range. + buckets: Vec, + }, +} + +impl From for BucketResult { + fn from(result: IntermediateBucketResult) -> Self { + match result { + IntermediateBucketResult::Range(range_map) => { + let mut buckets: Vec = range_map + .into_iter() + .map(|(_, bucket)| bucket.into()) + .collect_vec(); + + buckets.sort_by(|a, b| { + a.from + .unwrap_or(f64::MIN) + .partial_cmp(&b.from.unwrap_or(f64::MIN)) + .unwrap_or(Ordering::Equal) + }); + BucketResult::Range { buckets } + } + } + } +} + +/// This is the range entry for a bucket, which contains a key, count, and optionally +/// sub_aggregations. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct RangeBucketEntry { + /// The identifier of the bucket. + pub key: Key, + /// Number of documents in the bucket. + pub doc_count: u64, + #[serde(flatten)] + /// sub-aggregations in this bucket. + pub sub_aggregation: AggregationResults, + /// The from range of the bucket. Equals f64::MIN when None. + #[serde(skip_serializing_if = "Option::is_none")] + pub from: Option, + /// The to range of the bucket. Equals f64::MAX when None. + #[serde(skip_serializing_if = "Option::is_none")] + pub to: Option, +} + +impl From for RangeBucketEntry { + fn from(entry: IntermediateRangeBucketEntry) -> Self { + RangeBucketEntry { + key: entry.key, + doc_count: entry.doc_count, + sub_aggregation: entry.sub_aggregation.into(), + to: entry.to, + from: entry.from, + } + } +} diff --git a/src/aggregation/bucket/mod.rs b/src/aggregation/bucket/mod.rs new file mode 100644 index 000000000..22ceea0aa --- /dev/null +++ b/src/aggregation/bucket/mod.rs @@ -0,0 +1,10 @@ +//! Module for all bucket aggregations. +//! +//! Results of final buckets are [BucketEntry](super::agg_result::BucketEntry). +//! Results of intermediate buckets are +//! 
[IntermediateRangeBucketEntry](super::intermediate_agg_result::IntermediateRangeBucketEntry)

mod range;

pub use range::RangeAggregation;
pub(crate) use range::SegmentRangeCollector;
diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs new file mode 100644 index 000000000..e4ff6ee76 --- /dev/null +++ b/src/aggregation/bucket/range.rs @@ -0,0 +1,536 @@
use std::ops::Range;

use itertools::Itertools;
use serde::{Deserialize, Serialize};

use crate::aggregation::agg_req_with_accessor::{
    AggregationsWithAccessor, BucketAggregationWithAccessor,
};
use crate::aggregation::intermediate_agg_result::IntermediateBucketResult;
use crate::aggregation::segment_agg_result::{
    SegmentAggregationResultsCollector, SegmentRangeBucketEntry,
};
use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key};
use crate::fastfield::FastFieldReader;
use crate::schema::Type;
use crate::{DocId, TantivyError};

/// Provide user-defined buckets to aggregate on.
/// Two special buckets will automatically be created to cover the whole range of values.
/// The provided buckets have to be contiguous.
/// During the aggregation, the values extracted from the fast field `field` will be checked
/// against each bucket range. Note that this aggregation includes the from value and excludes the
/// to value for each range.
///
/// Result type is [BucketResult](crate::aggregation::agg_result::BucketResult) with
/// [RangeBucketEntry](crate::aggregation::agg_result::RangeBucketEntry) on the
/// AggregationCollector.
///
/// Result type is
/// [crate::aggregation::intermediate_agg_result::IntermediateBucketResult] with
/// [crate::aggregation::intermediate_agg_result::IntermediateRangeBucketEntry] on the
/// DistributedAggregationCollector.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct RangeAggregation {
    /// The field to aggregate on.
    pub field: String,
    /// Note that this aggregation includes the from value and excludes the to value for each
    /// range. Extra buckets will be created before the first `to` and after the last `from`, if
    /// necessary.
    pub ranges: Vec<RangeAggregationRange>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct RangeAggregationRange {
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub from: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub to: Option<f64>,
}

impl From<Range<f64>> for RangeAggregationRange {
    fn from(range: Range<f64>) -> Self {
        let from = if range.start == f64::MIN {
            None
        } else {
            Some(range.start)
        };
        let to = if range.end == f64::MAX {
            None
        } else {
            Some(range.end)
        };
        RangeAggregationRange { from, to }
    }
}

#[derive(Clone, Debug, PartialEq)]
pub struct SegmentRangeAndBucketEntry {
    range: Range<u64>,
    bucket: SegmentRangeBucketEntry,
}

/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug, PartialEq)]
pub struct SegmentRangeCollector {
    /// The buckets containing the aggregation data.
+ buckets: Vec, + field_type: Type, +} + +impl SegmentRangeCollector { + pub fn into_intermediate_bucket_result(self) -> IntermediateBucketResult { + let field_type = self.field_type; + + let buckets = self + .buckets + .into_iter() + .map(move |range_bucket| { + ( + range_to_key(&range_bucket.range, &field_type), + range_bucket.bucket.into(), + ) + }) + .collect(); + + IntermediateBucketResult::Range(buckets) + } + + pub(crate) fn from_req( + req: &RangeAggregation, + sub_aggregation: &AggregationsWithAccessor, + field_type: Type, + ) -> crate::Result { + // The range input on the request is f64. + // We need to convert to u64 ranges, because we read the values as u64. + // The mapping from the conversion is monotonic so ordering is preserved. + let buckets = extend_validate_ranges(&req.ranges, &field_type)? + .iter() + .map(|range| { + let to = if range.end == u64::MAX { + None + } else { + Some(f64_from_fastfield_u64(range.end, &field_type)) + }; + let from = if range.start == u64::MIN { + None + } else { + Some(f64_from_fastfield_u64(range.start, &field_type)) + }; + let sub_aggregation = if sub_aggregation.is_empty() { + None + } else { + Some(SegmentAggregationResultsCollector::from_req( + sub_aggregation, + )?) + }; + Ok(SegmentRangeAndBucketEntry { + range: range.clone(), + bucket: SegmentRangeBucketEntry { + key: range_to_key(range, &field_type), + doc_count: 0, + sub_aggregation, + from, + to, + }, + }) + }) + .collect::>()?; + + Ok(SegmentRangeCollector { + buckets, + field_type, + }) + } + + #[inline] + pub(crate) fn collect_block( + &mut self, + doc: &[DocId], + bucket_with_accessor: &BucketAggregationWithAccessor, + force_flush: bool, + ) { + let mut iter = doc.chunks_exact(4); + for docs in iter.by_ref() { + let val1 = bucket_with_accessor.accessor.get(docs[0]); + let val2 = bucket_with_accessor.accessor.get(docs[1]); + let val3 = bucket_with_accessor.accessor.get(docs[2]); + let val4 = bucket_with_accessor.accessor.get(docs[3]); + let bucket_pos1 = self.get_bucket_pos(val1); + let bucket_pos2 = self.get_bucket_pos(val2); + let bucket_pos3 = self.get_bucket_pos(val3); + let bucket_pos4 = self.get_bucket_pos(val4); + + self.increment_bucket(bucket_pos1, docs[0], &bucket_with_accessor.sub_aggregation); + self.increment_bucket(bucket_pos2, docs[1], &bucket_with_accessor.sub_aggregation); + self.increment_bucket(bucket_pos3, docs[2], &bucket_with_accessor.sub_aggregation); + self.increment_bucket(bucket_pos4, docs[3], &bucket_with_accessor.sub_aggregation); + } + for doc in iter.remainder() { + let val = bucket_with_accessor.accessor.get(*doc); + let bucket_pos = self.get_bucket_pos(val); + self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation); + } + if force_flush { + for bucket in &mut self.buckets { + if let Some(sub_aggregation) = &mut bucket.bucket.sub_aggregation { + sub_aggregation + .flush_staged_docs(&bucket_with_accessor.sub_aggregation, force_flush); + } + } + } + } + + #[inline] + fn increment_bucket( + &mut self, + bucket_pos: usize, + doc: DocId, + bucket_with_accessor: &AggregationsWithAccessor, + ) { + let bucket = &mut self.buckets[bucket_pos]; + + bucket.bucket.doc_count += 1; + if let Some(sub_aggregation) = &mut bucket.bucket.sub_aggregation { + sub_aggregation.collect(doc, bucket_with_accessor); + } + } + + #[inline] + fn get_bucket_pos(&self, val: u64) -> usize { + let pos = self + .buckets + .binary_search_by_key(&val, |probe| probe.range.start) + .unwrap_or_else(|pos| pos - 1); + 
debug_assert!(self.buckets[pos].range.contains(&val)); + pos + } +} + +/// Converts the user provided f64 range value to fast field value space. +/// +/// Internally fast field values are always stored as u64. +/// If the fast field has u64 [1,2,5], these values are stored as is in the fast field. +/// A fast field with f64 [1.0, 2.0, 5.0] is converted to u64 space, using a +/// monotonic mapping function, so the order is preserved. +/// +/// Consequently, a f64 user range 1.0..3.0 needs to be converted to fast field value space using +/// the same monotonic mapping function, so that the provided ranges contain the u64 values in the +/// fast field. +/// The alternative would be that every value read would be converted to the f64 range, but that is +/// more computational expensive when many documents are hit. +fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> Range { + range + .from + .map(|from| f64_to_fastfield_u64(from, field_type)) + .unwrap_or(u64::MIN) + ..range + .to + .map(|to| f64_to_fastfield_u64(to, field_type)) + .unwrap_or(u64::MAX) +} + +/// Extends the provided buckets to contain the whole value range, by inserting buckets at the +/// beginning and end. +fn extend_validate_ranges( + buckets: &[RangeAggregationRange], + field_type: &Type, +) -> crate::Result>> { + let mut converted_buckets = buckets + .iter() + .map(|range| to_u64_range(range, field_type)) + .collect_vec(); + + converted_buckets.sort_by_key(|bucket| bucket.start); + if converted_buckets[0].start != u64::MIN { + converted_buckets.insert(0, u64::MIN..converted_buckets[0].start); + } + + if converted_buckets[converted_buckets.len() - 1].end != u64::MAX { + converted_buckets.push(converted_buckets[converted_buckets.len() - 1].end..u64::MAX); + } + + // fill up holes in the ranges + let find_hole = |converted_buckets: &[Range]| { + for (pos, ranges) in converted_buckets.windows(2).enumerate() { + if ranges[0].end > ranges[1].start { + return Err(TantivyError::InvalidArgument(format!( + "Overlapping ranges not supported range {:?}, range+1 {:?}", + ranges[0], ranges[1] + ))); + } + if ranges[0].end != ranges[1].start { + return Ok(Some(pos)); + } + } + Ok(None) + }; + + while let Some(hole_pos) = find_hole(&converted_buckets)? { + let new_range = converted_buckets[hole_pos].end..converted_buckets[hole_pos + 1].start; + converted_buckets.insert(hole_pos + 1, new_range); + } + + Ok(converted_buckets) +} + +pub fn range_to_string(range: &Range, field_type: &Type) -> String { + // is_start is there for malformed requests, e.g. 
ig the user passes the range u64::MIN..0.0, + // it should be rendererd as "*-0" and not "*-*" + let to_str = |val: u64, is_start: bool| { + if (is_start && val == u64::MIN) || (!is_start && val == u64::MAX) { + "*".to_string() + } else { + f64_from_fastfield_u64(val, field_type).to_string() + } + }; + + format!("{}-{}", to_str(range.start, true), to_str(range.end, false)) +} + +pub fn range_to_key(range: &Range, field_type: &Type) -> Key { + Key::Str(range_to_string(range, field_type)) +} + +#[cfg(test)] +mod tests { + + use serde_json::Value; + + use super::*; + use crate::aggregation::agg_req::{ + Aggregation, Aggregations, BucketAggregation, BucketAggregationType, + }; + use crate::aggregation::tests::get_test_index_with_num_docs; + use crate::aggregation::AggregationCollector; + use crate::fastfield::FastValue; + use crate::query::AllQuery; + + pub fn get_collector_from_ranges( + ranges: Vec, + field_type: Type, + ) -> SegmentRangeCollector { + let req = RangeAggregation { + field: "dummy".to_string(), + ranges, + }; + + SegmentRangeCollector::from_req(&req, &Default::default(), field_type).unwrap() + } + + #[test] + fn range_fraction_test() -> crate::Result<()> { + let index = get_test_index_with_num_docs(false, 100)?; + + let agg_req: Aggregations = vec![( + "range".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Range(RangeAggregation { + field: "fraction_f64".to_string(), + ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()], + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let collector = AggregationCollector::from_aggs(agg_req); + + let reader = index.reader()?; + let searcher = reader.searcher(); + let agg_res = searcher.search(&AllQuery, &collector).unwrap(); + + let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?; + + assert_eq!(res["range"]["buckets"][0]["key"], "*-0"); + assert_eq!(res["range"]["buckets"][0]["doc_count"], 0); + assert_eq!(res["range"]["buckets"][1]["key"], "0-0.1"); + assert_eq!(res["range"]["buckets"][1]["doc_count"], 10); + assert_eq!(res["range"]["buckets"][2]["key"], "0.1-0.2"); + assert_eq!(res["range"]["buckets"][2]["doc_count"], 10); + assert_eq!(res["range"]["buckets"][3]["key"], "0.2-*"); + assert_eq!(res["range"]["buckets"][3]["doc_count"], 80); + + Ok(()) + } + + #[test] + fn bucket_test_extend_range_hole() { + let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()]; + let collector = get_collector_from_ranges(buckets, Type::F64); + + let buckets = collector.buckets; + assert_eq!(buckets[0].range.start, u64::MIN); + assert_eq!(buckets[0].range.end, 10f64.to_u64()); + assert_eq!(buckets[1].range.start, 10f64.to_u64()); + assert_eq!(buckets[1].range.end, 20f64.to_u64()); + // Added bucket to fill hole + assert_eq!(buckets[2].range.start, 20f64.to_u64()); + assert_eq!(buckets[2].range.end, 30f64.to_u64()); + assert_eq!(buckets[3].range.start, 30f64.to_u64()); + assert_eq!(buckets[3].range.end, 40f64.to_u64()); + } + + #[test] + fn bucket_test_range_conversion_special_case() { + // the monotonic conversion between f64 and u64, does not map f64::MIN.to_u64() == + // u64::MIN, but the into trait converts f64::MIN/MAX to None + let buckets = vec![ + (f64::MIN..10f64).into(), + (10f64..20f64).into(), + (20f64..f64::MAX).into(), + ]; + let collector = get_collector_from_ranges(buckets, Type::F64); + + let buckets = collector.buckets; + assert_eq!(buckets[0].range.start, u64::MIN); + assert_eq!(buckets[0].range.end, 10f64.to_u64()); + 
assert_eq!(buckets[1].range.start, 10f64.to_u64()); + assert_eq!(buckets[1].range.end, 20f64.to_u64()); + assert_eq!(buckets[2].range.start, 20f64.to_u64()); + assert_eq!(buckets[2].range.end, u64::MAX); + assert_eq!(buckets.len(), 3); + } + + #[test] + fn bucket_range_test_negative_vals() { + let buckets = vec![(-10f64..-1f64).into()]; + let collector = get_collector_from_ranges(buckets, Type::F64); + + let buckets = collector.buckets; + assert_eq!(&buckets[0].bucket.key.to_string(), "*--10"); + assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*"); + } + #[test] + fn bucket_range_test_positive_vals() { + let buckets = vec![(0f64..10f64).into()]; + let collector = get_collector_from_ranges(buckets, Type::F64); + + let buckets = collector.buckets; + assert_eq!(&buckets[0].bucket.key.to_string(), "*-0"); + assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*"); + } + + #[test] + fn range_binary_search_test_u64() { + let check_ranges = |ranges: Vec| { + let collector = get_collector_from_ranges(ranges, Type::U64); + let search = |val: u64| collector.get_bucket_pos(val); + + assert_eq!(search(u64::MIN), 0); + assert_eq!(search(9), 0); + assert_eq!(search(10), 1); + assert_eq!(search(11), 1); + assert_eq!(search(99), 1); + assert_eq!(search(100), 2); + assert_eq!(search(u64::MAX - 1), 2); // Since the end range is never included, the max + // value + }; + + let ranges = vec![(10.0..100.0).into()]; + check_ranges(ranges); + + let ranges = vec![ + RangeAggregationRange { + to: Some(10.0), + from: None, + }, + (10.0..100.0).into(), + ]; + check_ranges(ranges); + + let ranges = vec![ + RangeAggregationRange { + to: Some(10.0), + from: None, + }, + (10.0..100.0).into(), + RangeAggregationRange { + to: None, + from: Some(100.0), + }, + ]; + check_ranges(ranges); + } + + #[test] + fn range_binary_search_test_f64() { + let ranges = vec![ + //(f64::MIN..10.0).into(), + (10.0..100.0).into(), + //(100.0..f64::MAX).into(), + ]; + + let collector = get_collector_from_ranges(ranges, Type::F64); + let search = |val: u64| collector.get_bucket_pos(val); + + assert_eq!(search(u64::MIN), 0); + assert_eq!(search(9f64.to_u64()), 0); + assert_eq!(search(10f64.to_u64()), 1); + assert_eq!(search(11f64.to_u64()), 1); + assert_eq!(search(99f64.to_u64()), 1); + assert_eq!(search(100f64.to_u64()), 2); + assert_eq!(search(u64::MAX - 1), 2); // Since the end range is never included, + // the max value + } +} + +#[cfg(all(test, feature = "unstable"))] +mod bench { + + use rand::seq::SliceRandom; + use rand::thread_rng; + + use super::*; + use crate::aggregation::bucket::range::tests::get_collector_from_ranges; + + const TOTAL_DOCS: u64 = 1_000_000u64; + const NUM_DOCS: u64 = 50_000u64; + + fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector { + let bucket_size = num_docs / num_buckets; + let mut buckets: Vec = vec![]; + for i in 0..num_buckets { + let bucket_start = (i * bucket_size) as f64; + buckets.push((bucket_start..bucket_start + bucket_size as f64).into()) + } + + get_collector_from_ranges(buckets, Type::U64) + } + + fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec { + let mut rng = thread_rng(); + + let all_docs = (0..total_docs - 1).collect_vec(); + let mut vals = all_docs + .as_slice() + .choose_multiple(&mut rng, num_docs_returned as usize) + .cloned() + .collect_vec(); + vals.sort(); + vals + } + + fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) { + let collector = get_collector_with_buckets(num_buckets, 
TOTAL_DOCS);
        let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS);
        b.iter(|| {
            let mut bucket_pos = 0;
            for val in &vals {
                bucket_pos = collector.get_bucket_pos(*val);
            }
            bucket_pos
        })
    }

    #[bench]
    fn bench_range_100_buckets(b: &mut test::Bencher) {
        bench_range_binary_search(b, 100)
    }

    #[bench]
    fn bench_range_10_buckets(b: &mut test::Bencher) {
        bench_range_binary_search(b, 10)
    }
}
diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs new file mode 100644 index 000000000..2c3d2d62e --- /dev/null +++ b/src/aggregation/collector.rs @@ -0,0 +1,135 @@
use super::agg_req::Aggregations;
use super::agg_req_with_accessor::AggregationsWithAccessor;
use super::agg_result::AggregationResults;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationResultsCollector;
use crate::aggregation::agg_req_with_accessor::get_aggregations_with_accessor;
use crate::collector::{Collector, SegmentCollector};
use crate::TantivyError;

/// Collector for aggregations.
///
/// The collector collects all aggregations specified in the underlying aggregation request.
pub struct AggregationCollector {
    agg: Aggregations,
}

impl AggregationCollector {
    /// Create a collector from an aggregation request.
    pub fn from_aggs(agg: Aggregations) -> Self {
        Self { agg }
    }
}

/// Collector for distributed aggregations.
///
/// The collector collects all aggregations specified in the underlying aggregation request.
///
/// # Purpose
/// DistributedAggregationCollector returns `IntermediateAggregationResults` and not the final
/// `AggregationResults`, so that results from different indices can be merged and then converted
/// into the final `AggregationResults` via the `into()` method.
pub struct DistributedAggregationCollector {
    agg: Aggregations,
}

impl DistributedAggregationCollector {
    /// Create a collector from an aggregation request.
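    ///
    /// A rough usage sketch (not a doctest; the searchers, `query` and `agg_req` are assumed to
    /// exist, and error handling is elided):
    ///
    /// ```ignore
    /// let collector = DistributedAggregationCollector::from_aggs(agg_req);
    /// // Collect intermediate, mergeable results per index...
    /// let mut intermediate = searcher_a.search(&query, &collector)?;
    /// let other = searcher_b.search(&query, &collector)?;
    /// // ...merge them and convert into the final result tree.
    /// intermediate.merge_fruits(&other);
    /// let final_result: AggregationResults = intermediate.into();
    /// ```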
+ pub fn from_aggs(agg: Aggregations) -> Self { + Self { agg } + } +} + +impl Collector for DistributedAggregationCollector { + type Fruit = IntermediateAggregationResults; + + type Child = AggregationSegmentCollector; + + fn for_segment( + &self, + _segment_local_id: crate::SegmentOrdinal, + reader: &crate::SegmentReader, + ) -> crate::Result { + let aggs_with_accessor = get_aggregations_with_accessor(&self.agg, reader)?; + let result = SegmentAggregationResultsCollector::from_req(&aggs_with_accessor)?; + Ok(AggregationSegmentCollector { + aggs: aggs_with_accessor, + result, + }) + } + + fn requires_scoring(&self) -> bool { + false + } + + fn merge_fruits( + &self, + segment_fruits: Vec<::Fruit>, + ) -> crate::Result { + merge_fruits(segment_fruits) + } +} + +impl Collector for AggregationCollector { + type Fruit = AggregationResults; + + type Child = AggregationSegmentCollector; + + fn for_segment( + &self, + _segment_local_id: crate::SegmentOrdinal, + reader: &crate::SegmentReader, + ) -> crate::Result { + let aggs_with_accessor = get_aggregations_with_accessor(&self.agg, reader)?; + let result = SegmentAggregationResultsCollector::from_req(&aggs_with_accessor)?; + Ok(AggregationSegmentCollector { + aggs: aggs_with_accessor, + result, + }) + } + + fn requires_scoring(&self) -> bool { + false + } + + fn merge_fruits( + &self, + segment_fruits: Vec<::Fruit>, + ) -> crate::Result { + merge_fruits(segment_fruits).map(|res| res.into()) + } +} + +fn merge_fruits( + mut segment_fruits: Vec, +) -> crate::Result { + if let Some(mut fruit) = segment_fruits.pop() { + for next_fruit in segment_fruits { + fruit.merge_fruits(&next_fruit); + } + Ok(fruit) + } else { + Err(TantivyError::InvalidArgument( + "no fruits provided in merge_fruits".to_string(), + )) + } +} + +pub struct AggregationSegmentCollector { + aggs: AggregationsWithAccessor, + result: SegmentAggregationResultsCollector, +} + +impl SegmentCollector for AggregationSegmentCollector { + type Fruit = IntermediateAggregationResults; + + #[inline] + fn collect(&mut self, doc: crate::DocId, _score: crate::Score) { + self.result.collect(doc, &self.aggs); + } + + fn harvest(mut self) -> Self::Fruit { + self.result.flush_staged_docs(&self.aggs, true); + self.result.into() + } +} diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs new file mode 100644 index 000000000..9e457ab23 --- /dev/null +++ b/src/aggregation/intermediate_agg_result.rs @@ -0,0 +1,304 @@ +//! Contains the intermediate aggregation tree, that can be merged. +//! Intermediate aggregation results can be used to merge results between segments or between +//! indices. + +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use super::metric::{IntermediateAverage, IntermediateStats}; +use super::segment_agg_result::{ + SegmentAggregationResultsCollector, SegmentBucketResultCollector, SegmentMetricResultCollector, + SegmentRangeBucketEntry, +}; +use super::{Key, VecWithNames}; + +/// Contains the intermediate aggregation result, which is optimized to be merged with other +/// intermediate results. 
+#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct IntermediateAggregationResults(pub(crate) VecWithNames); + +impl From for IntermediateAggregationResults { + fn from(tree: SegmentAggregationResultsCollector) -> Self { + let mut data = vec![]; + for (key, bucket) in tree.buckets.into_iter() { + data.push((key, IntermediateAggregationResult::Bucket(bucket.into()))); + } + for (key, metric) in tree.metrics.into_iter() { + data.push((key, IntermediateAggregationResult::Metric(metric.into()))); + } + Self(VecWithNames::from_entries(data)) + } +} + +impl IntermediateAggregationResults { + /// Merge an other intermediate aggregation result into this result. + /// + /// The order of the values need to be the same on both results. This is ensured when the same + /// (key values) are present on the underlying VecWithNames struct. + pub fn merge_fruits(&mut self, other: &IntermediateAggregationResults) { + for (tree_left, tree_right) in self.0.values_mut().zip(other.0.values()) { + tree_left.merge_fruits(tree_right); + } + } +} + +/// An aggregation is either a bucket or a metric. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum IntermediateAggregationResult { + /// Bucket variant + Bucket(IntermediateBucketResult), + /// Metric variant + Metric(IntermediateMetricResult), +} + +impl IntermediateAggregationResult { + fn merge_fruits(&mut self, other: &IntermediateAggregationResult) { + match (self, other) { + ( + IntermediateAggregationResult::Bucket(res_left), + IntermediateAggregationResult::Bucket(res_right), + ) => { + res_left.merge_fruits(res_right); + } + ( + IntermediateAggregationResult::Metric(res_left), + IntermediateAggregationResult::Metric(res_right), + ) => { + res_left.merge_fruits(res_right); + } + _ => { + panic!("incompatible types in aggregation tree on merge fruits"); + } + } + } +} + +/// Holds the intermediate data for metric results +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum IntermediateMetricResult { + /// Average containing intermediate average data result + Average(IntermediateAverage), + /// AverageData variant + Stats(IntermediateStats), +} + +impl From for IntermediateMetricResult { + fn from(tree: SegmentMetricResultCollector) -> Self { + match tree { + SegmentMetricResultCollector::Average(collector) => { + IntermediateMetricResult::Average(IntermediateAverage::from_collector(collector)) + } + SegmentMetricResultCollector::Stats(collector) => { + IntermediateMetricResult::Stats(collector.stats) + } + } + } +} + +impl IntermediateMetricResult { + fn merge_fruits(&mut self, other: &IntermediateMetricResult) { + match (self, other) { + ( + IntermediateMetricResult::Average(avg_data_left), + IntermediateMetricResult::Average(avg_data_right), + ) => { + avg_data_left.merge_fruits(avg_data_right); + } + ( + IntermediateMetricResult::Stats(stats_left), + IntermediateMetricResult::Stats(stats_right), + ) => { + stats_left.merge_fruits(stats_right); + } + _ => { + panic!("incompatible fruit types in tree {:?}", other); + } + } + } +} + +/// The intermediate bucket results. Internally they can be easily merged via the keys of the +/// buckets. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum IntermediateBucketResult { + /// This is the range entry for a bucket, which contains a key, count, from, to, and optionally + /// sub_aggregations. 
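    ///
    /// The map key is the serialized range key produced for each bucket (e.g. `"3-7"`, or
    /// `"*-3"` for an open lower bound), which allows results from different segments to be
    /// merged with a straightforward key-wise merge.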
+ Range(HashMap), +} + +impl From for IntermediateBucketResult { + fn from(collector: SegmentBucketResultCollector) -> Self { + match collector { + SegmentBucketResultCollector::Range(range) => range.into_intermediate_bucket_result(), + } + } +} + +impl IntermediateBucketResult { + fn merge_fruits(&mut self, other: &IntermediateBucketResult) { + match (self, other) { + ( + IntermediateBucketResult::Range(entries_left), + IntermediateBucketResult::Range(entries_right), + ) => { + for (name, entry_left) in entries_left.iter_mut() { + if let Some(entry_right) = entries_right.get(name) { + entry_left.merge_fruits(entry_right); + } + } + + for (key, res) in entries_right.iter() { + if !entries_left.contains_key(key) { + entries_left.insert(key.clone(), res.clone()); + } + } + } + } + } +} + +/// This is the range entry for a bucket, which contains a key, count, and optionally +/// sub_aggregations. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct IntermediateRangeBucketEntry { + /// The unique the bucket is identified. + pub key: Key, + /// The number of documents in the bucket. + pub doc_count: u64, + pub(crate) values: Option>, + /// The sub_aggregation in this bucket. + pub sub_aggregation: IntermediateAggregationResults, + /// The from range of the bucket. Equals f64::MIN when None. + #[serde(skip_serializing_if = "Option::is_none")] + pub from: Option, + /// The to range of the bucket. Equals f64::MAX when None. + #[serde(skip_serializing_if = "Option::is_none")] + pub to: Option, +} + +impl From for IntermediateRangeBucketEntry { + fn from(entry: SegmentRangeBucketEntry) -> Self { + let sub_aggregation = if let Some(sub_aggregation) = entry.sub_aggregation { + sub_aggregation.into() + } else { + Default::default() + }; + // let sub_aggregation = entry.sub_aggregation.into(); + + IntermediateRangeBucketEntry { + key: entry.key, + doc_count: entry.doc_count, + values: None, + sub_aggregation, + to: entry.to, + from: entry.from, + } + } +} + +impl IntermediateRangeBucketEntry { + fn merge_fruits(&mut self, other: &IntermediateRangeBucketEntry) { + self.doc_count += other.doc_count; + self.sub_aggregation.merge_fruits(&other.sub_aggregation); + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + fn get_sub_test_tree(data: &[(String, u64)]) -> IntermediateAggregationResults { + let mut map = HashMap::new(); + let mut buckets = HashMap::new(); + for (key, doc_count) in data { + buckets.insert( + Key::Str(key.to_string()), + IntermediateRangeBucketEntry { + key: Key::Str(key.to_string()), + doc_count: *doc_count, + values: None, + sub_aggregation: Default::default(), + from: None, + to: None, + }, + ); + } + map.insert( + "my_agg_level2".to_string(), + IntermediateAggregationResult::Bucket(IntermediateBucketResult::Range(buckets)), + ); + IntermediateAggregationResults(VecWithNames::from_entries(map.into_iter().collect())) + } + + fn get_test_tree(data: &[(String, u64, String, u64)]) -> IntermediateAggregationResults { + let mut map = HashMap::new(); + let mut buckets = HashMap::new(); + for (key, doc_count, sub_aggregation_key, sub_aggregation_count) in data { + buckets.insert( + Key::Str(key.to_string()), + IntermediateRangeBucketEntry { + key: Key::Str(key.to_string()), + doc_count: *doc_count, + values: None, + from: None, + to: None, + sub_aggregation: get_sub_test_tree(&[( + sub_aggregation_key.to_string(), + *sub_aggregation_count, + )]), + }, + ); + } + map.insert( + "my_agg_level1".to_string(), + 
IntermediateAggregationResult::Bucket(IntermediateBucketResult::Range(buckets)), + ); + IntermediateAggregationResults(VecWithNames::from_entries(map.into_iter().collect())) + } + + #[test] + fn test_merge_fruits_tree_1() { + let mut tree_left = get_test_tree(&[ + ("red".to_string(), 50, "1900".to_string(), 25), + ("blue".to_string(), 30, "1900".to_string(), 30), + ]); + let tree_right = get_test_tree(&[ + ("red".to_string(), 60, "1900".to_string(), 30), + ("blue".to_string(), 25, "1900".to_string(), 50), + ]); + + tree_left.merge_fruits(&tree_right); + + let tree_expected = get_test_tree(&[ + ("red".to_string(), 110, "1900".to_string(), 55), + ("blue".to_string(), 55, "1900".to_string(), 80), + ]); + + assert_eq!(tree_left, tree_expected); + } + + #[test] + fn test_merge_fruits_tree_2() { + let mut tree_left = get_test_tree(&[ + ("red".to_string(), 50, "1900".to_string(), 25), + ("blue".to_string(), 30, "1900".to_string(), 30), + ]); + let tree_right = get_test_tree(&[ + ("red".to_string(), 60, "1900".to_string(), 30), + ("green".to_string(), 25, "1900".to_string(), 50), + ]); + + tree_left.merge_fruits(&tree_right); + + let tree_expected = get_test_tree(&[ + ("red".to_string(), 110, "1900".to_string(), 55), + ("blue".to_string(), 30, "1900".to_string(), 30), + ("green".to_string(), 25, "1900".to_string(), 50), + ]); + + assert_eq!(tree_left, tree_expected); + } +} diff --git a/src/aggregation/metric/average.rs b/src/aggregation/metric/average.rs new file mode 100644 index 000000000..a83ae6530 --- /dev/null +++ b/src/aggregation/metric/average.rs @@ -0,0 +1,101 @@ +use std::fmt::Debug; + +use serde::{Deserialize, Serialize}; + +use crate::aggregation::f64_from_fastfield_u64; +use crate::fastfield::{DynamicFastFieldReader, FastFieldReader}; +use crate::schema::Type; +use crate::DocId; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +/// A single-value metric aggregation that computes the average of numeric values that are +/// extracted from the aggregated documents. +/// Supported field types are u64, i64, and f64. +/// See [super::SingleMetricResult] for return value. +pub struct AverageAggregation { + /// The field name to compute the stats on. + pub field: String, +} +impl AverageAggregation { + /// Create new AverageAggregation from a field. + pub fn from_field_name(field_name: String) -> Self { + AverageAggregation { field: field_name } + } + /// Return the field name. 
+ pub fn field_name(&self) -> &str { + &self.field + } +} + +#[derive(Clone, PartialEq)] +pub(crate) struct SegmentAverageCollector { + pub data: IntermediateAverage, + field_type: Type, +} + +impl Debug for SegmentAverageCollector { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AverageCollector") + .field("data", &self.data) + .finish() + } +} + +impl SegmentAverageCollector { + pub fn from_req(field_type: Type) -> Self { + Self { + field_type, + data: Default::default(), + } + } + pub(crate) fn collect_block(&mut self, doc: &[DocId], field: &DynamicFastFieldReader) { + let mut iter = doc.chunks_exact(4); + for docs in iter.by_ref() { + let val1 = field.get(docs[0]); + let val2 = field.get(docs[1]); + let val3 = field.get(docs[2]); + let val4 = field.get(docs[3]); + let val1 = f64_from_fastfield_u64(val1, &self.field_type); + let val2 = f64_from_fastfield_u64(val2, &self.field_type); + let val3 = f64_from_fastfield_u64(val3, &self.field_type); + let val4 = f64_from_fastfield_u64(val4, &self.field_type); + self.data.collect(val1); + self.data.collect(val2); + self.data.collect(val3); + self.data.collect(val4); + } + for doc in iter.remainder() { + let val = field.get(*doc); + let val = f64_from_fastfield_u64(val, &self.field_type); + self.data.collect(val); + } + } +} + +/// Contains mergeable version of average data. +#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct IntermediateAverage { + pub(crate) sum: f64, + pub(crate) doc_count: u64, +} + +impl IntermediateAverage { + pub(crate) fn from_collector(collector: SegmentAverageCollector) -> Self { + collector.data + } + + /// Merge average data into this instance. + pub fn merge_fruits(&mut self, other: &IntermediateAverage) { + self.sum += other.sum; + self.doc_count += other.doc_count; + } + /// compute final result + pub fn finalize(&self) -> f64 { + self.sum / self.doc_count as f64 + } + #[inline] + fn collect(&mut self, val: f64) { + self.doc_count += 1; + self.sum += val; + } +} diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs new file mode 100644 index 000000000..e7260ac30 --- /dev/null +++ b/src/aggregation/metric/mod.rs @@ -0,0 +1,22 @@ +//! Module for all metric aggregations. + +mod average; +mod stats; +pub use average::*; +use serde::{Deserialize, Serialize}; +pub use stats::*; + +/// Single-metric aggregations use this common result structure. +/// +/// Main reason to wrap it in value is to match elasticsearch output structure. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct SingleMetricResult { + /// The value of the single value metric. + pub value: f64, +} + +impl From for SingleMetricResult { + fn from(value: f64) -> Self { + Self { value } + } +} diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs new file mode 100644 index 000000000..90633510d --- /dev/null +++ b/src/aggregation/metric/stats.rs @@ -0,0 +1,273 @@ +use serde::{Deserialize, Serialize}; + +use crate::aggregation::f64_from_fastfield_u64; +use crate::fastfield::{DynamicFastFieldReader, FastFieldReader}; +use crate::schema::Type; +use crate::DocId; + +/// A multi-value metric aggregation that computes stats of numeric values that are +/// extracted from the aggregated documents. +/// Supported field types are u64, i64, and f64. +/// See [Stats] for returned statistics. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct StatsAggregation { + /// The field name to compute the stats on. 
+    pub field: String,
+}
+impl StatsAggregation {
+    /// Create a new StatsAggregation from a field name.
+    pub fn from_field_name(field_name: String) -> Self {
+        StatsAggregation { field: field_name }
+    }
+    /// Return the field name.
+    pub fn field_name(&self) -> &str {
+        &self.field
+    }
+}
+
+/// Stats contains a collection of statistics.
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub struct Stats {
+    /// The number of documents.
+    pub count: usize,
+    /// The sum of the fast field values.
+    pub sum: f64,
+    /// The standard deviation of the fast field values.
+    pub standard_deviation: f64,
+    /// The min value of the fast field values.
+    pub min: f64,
+    /// The max value of the fast field values.
+    pub max: f64,
+    /// The average of the values.
+    pub avg: f64,
+}
+
+/// IntermediateStats contains the mergeable version of [Stats].
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub struct IntermediateStats {
+    count: usize,
+    sum: f64,
+    squared_sum: f64,
+    min: f64,
+    max: f64,
+}
+
+impl IntermediateStats {
+    fn new() -> Self {
+        Self {
+            count: 0,
+            sum: 0.0,
+            squared_sum: 0.0,
+            min: f64::MAX,
+            max: f64::MIN,
+        }
+    }
+
+    pub(crate) fn avg(&self) -> f64 {
+        self.sum / (self.count as f64)
+    }
+
+    fn square_mean(&self) -> f64 {
+        self.squared_sum / (self.count as f64)
+    }
+
+    // Population standard deviation: sqrt(E[x^2] - E[x]^2).
+    pub(crate) fn standard_deviation(&self) -> f64 {
+        let average = self.avg();
+        (self.square_mean() - average * average).sqrt()
+    }
+
+    /// Merge data from other stats into this instance.
+    pub fn merge_fruits(&mut self, other: &IntermediateStats) {
+        self.count += other.count;
+        self.sum += other.sum;
+        self.squared_sum += other.squared_sum;
+        self.min = self.min.min(other.min);
+        self.max = self.max.max(other.max);
+    }
+
+    /// Compute the final result.
+    pub fn finalize(&self) -> Stats {
+        Stats {
+            count: self.count,
+            sum: self.sum,
+            standard_deviation: self.standard_deviation(),
+            min: self.min,
+            max: self.max,
+            avg: self.avg(),
+        }
+    }
+
+    #[inline]
+    fn collect(&mut self, value: f64) {
+        self.count += 1;
+        self.sum += value;
+        self.squared_sum += value * value;
+        self.min = self.min.min(value);
+        self.max = self.max.max(value);
+    }
+}
+
+#[derive(Clone, Debug, PartialEq)]
+pub(crate) struct SegmentStatsCollector {
+    pub(crate) stats: IntermediateStats,
+    field_type: Type,
+}
+
+impl SegmentStatsCollector {
+    pub fn from_req(field_type: Type) -> Self {
+        Self {
+            field_type,
+            stats: IntermediateStats::new(),
+        }
+    }
+    pub(crate) fn collect_block(&mut self, doc: &[DocId], field: &DynamicFastFieldReader<u64>) {
+        let mut iter = doc.chunks_exact(4);
+        for docs in iter.by_ref() {
+            let val1 = field.get(docs[0]);
+            let val2 = field.get(docs[1]);
+            let val3 = field.get(docs[2]);
+            let val4 = field.get(docs[3]);
+            let val1 = f64_from_fastfield_u64(val1, &self.field_type);
+            let val2 = f64_from_fastfield_u64(val2, &self.field_type);
+            let val3 = f64_from_fastfield_u64(val3, &self.field_type);
+            let val4 = f64_from_fastfield_u64(val4, &self.field_type);
+            self.stats.collect(val1);
+            self.stats.collect(val2);
+            self.stats.collect(val3);
+            self.stats.collect(val4);
+        }
+        for doc in iter.remainder() {
+            let val = field.get(*doc);
+            let val = f64_from_fastfield_u64(val, &self.field_type);
+            self.stats.collect(val);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use std::iter;
+
+    use serde_json::{json, Value};
+
+    use crate::aggregation::agg_req::{
+        Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation,
+        RangeAggregation,
+    };
+    use crate::aggregation::agg_result::AggregationResults;
+    use crate::aggregation::metric::StatsAggregation;
+    use crate::aggregation::tests::get_test_index_2_segments;
+    use crate::aggregation::AggregationCollector;
+    use crate::query::TermQuery;
+    use crate::schema::IndexRecordOption;
+    use crate::Term;
+
+    #[test]
+    fn test_aggregation_stats() -> crate::Result<()> {
+        let index = get_test_index_2_segments(false)?;
+
+        let reader = index.reader()?;
+        let text_field = reader.searcher().schema().get_field("text").unwrap();
+
+        let term_query = TermQuery::new(
+            Term::from_field_text(text_field, "cool"),
+            IndexRecordOption::Basic,
+        );
+
+        let agg_req_1: Aggregations = vec![
+            (
+                "stats_i64".to_string(),
+                Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
+                    "score_i64".to_string(),
+                ))),
+            ),
+            (
+                "stats_f64".to_string(),
+                Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
+                    "score_f64".to_string(),
+                ))),
+            ),
+            (
+                "stats".to_string(),
+                Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
+                    "score".to_string(),
+                ))),
+            ),
+            (
+                "range".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Range(RangeAggregation {
+                        field: "score".to_string(),
+                        ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
+                    }),
+                    sub_aggregation: iter::once((
+                        "stats".to_string(),
+                        Aggregation::Metric(MetricAggregation::Stats(
+                            StatsAggregation::from_field_name("score".to_string()),
+                        )),
+                    ))
+                    .collect(),
+                }),
+            ),
+        ]
+        .into_iter()
+        .collect();
+
+        let collector = AggregationCollector::from_aggs(agg_req_1);
+
+        let searcher = reader.searcher();
+        let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
+
+        let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
+        assert_eq!(
+            res["stats"],
+            json!({
+                "avg": 12.142857142857142,
+                "count": 7,
+                "max": 44.0,
+                "min": 1.0,
+                "standard_deviation": 13.65313748796613,
+                "sum": 85.0
+            })
+        );
+
+        assert_eq!(
+            res["stats_i64"],
+            json!({
+                "avg": 12.142857142857142,
+                "count": 7,
+                "max": 44.0,
+                "min": 1.0,
+                "standard_deviation": 13.65313748796613,
+                "sum": 85.0
+            })
+        );
+
+        assert_eq!(
+            res["stats_f64"],
+            json!({
+                "avg": 12.214285714285714,
+                "count": 7,
+                "max": 44.5,
+                "min": 1.0,
+                "standard_deviation": 13.819905785437443,
+                "sum": 85.5
+            })
+        );
+
+        assert_eq!(
+            res["range"]["buckets"][2]["stats"],
+            json!({
+                "avg": 10.666666666666666,
+                "count": 3,
+                "max": 14.0,
+                "min": 7.0,
+                "standard_deviation": 2.867441755680877,
+                "sum": 32.0
+            })
+        );
+
+        Ok(())
+    }
+}
diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs
new file mode 100644
index 000000000..493c0e2e8
--- /dev/null
+++ b/src/aggregation/mod.rs
@@ -0,0 +1,1148 @@
+//! # Aggregations
+//!
+//! Aggregation summarizes your data as statistics on buckets or metrics.
+//!
+//! Aggregations can provide answers to questions like:
+//! - What is the average price of all sold articles?
+//! - How many errors with status code 500 do we have per day?
+//! - What is the average listing price of cars grouped by color?
+//!
+//! There are two categories: [Metrics](metric) and [Buckets](bucket).
+//!
+//! # Usage
+//!
+//! To use aggregations, build an aggregation request by constructing [agg_req::Aggregations].
+//! Create an [AggregationCollector] from this request. AggregationCollector implements the
+//! `Collector` trait and can be passed as a collector into `searcher.search()`.
+//!
+//! # Example
+//! Compute the average metric by building [agg_req::Aggregations], which is constructed from an
+//! iterator of (String, [agg_req::Aggregation]) pairs.
+//!
+//! ```
+//! use tantivy::aggregation::agg_req::{Aggregations, Aggregation, MetricAggregation};
+//! use tantivy::aggregation::AggregationCollector;
+//! use tantivy::aggregation::metric::AverageAggregation;
+//! use tantivy::query::AllQuery;
+//! use tantivy::aggregation::agg_result::AggregationResults;
+//! use tantivy::IndexReader;
+//!
+//! # #[allow(dead_code)]
+//! fn aggregate_on_index(reader: &IndexReader) {
+//!     let agg_req: Aggregations = vec![
+//!         (
+//!             "average".to_string(),
+//!             Aggregation::Metric(MetricAggregation::Average(
+//!                 AverageAggregation::from_field_name("score".to_string()),
+//!             )),
+//!         ),
+//!     ]
+//!     .into_iter()
+//!     .collect();
+//!
+//!     let collector = AggregationCollector::from_aggs(agg_req);
+//!
+//!     let searcher = reader.searcher();
+//!     let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
+//! }
+//! ```
+//! # Example JSON
+//! Requests are compatible with the Elasticsearch JSON request format.
+//!
+//! ```
+//! use tantivy::aggregation::agg_req::Aggregations;
+//!
+//! let elasticsearch_compatible_json_req = r#"
+//! {
+//!   "average": {
+//!     "avg": { "field": "score" }
+//!   },
+//!   "range": {
+//!     "range": {
+//!       "field": "score",
+//!       "ranges": [
+//!         { "to": 3.0 },
+//!         { "from": 3.0, "to": 7.0 },
+//!         { "from": 7.0, "to": 20.0 },
+//!         { "from": 20.0 }
+//!       ]
+//!     },
+//!     "aggs": {
+//!       "average_in_range": { "avg": { "field": "score" } }
+//!     }
+//!   }
+//! }
+//! "#;
+//! let agg_req: Aggregations = serde_json::from_str(elasticsearch_compatible_json_req).unwrap();
+//! ```
+//! # Code Organization
+//!
+//! Check the [README](https://github.com/quickwit-oss/tantivy/tree/main/src/aggregation#readme) on GitHub to see how the code is organized.
+//!
+//! # Nested Aggregation
+//!
+//! Buckets can contain sub-aggregations. In this example we create buckets with the range
+//! aggregation and then calculate the average on each bucket.
+//! ```
+//! use tantivy::aggregation::agg_req::{Aggregations, Aggregation, BucketAggregation,
+//!     MetricAggregation, BucketAggregationType};
+//! use tantivy::aggregation::metric::AverageAggregation;
+//! use tantivy::aggregation::bucket::RangeAggregation;
+//! let sub_agg_req_1: Aggregations = vec![(
+//!     "average_in_range".to_string(),
+//!     Aggregation::Metric(MetricAggregation::Average(
+//!         AverageAggregation::from_field_name("score".to_string()),
+//!     )),
+//! )]
+//! .into_iter()
+//! .collect();
+//!
+//! let agg_req_1: Aggregations = vec![
+//!     (
+//!         "range".to_string(),
+//!         Aggregation::Bucket(BucketAggregation {
+//!             bucket_agg: BucketAggregationType::Range(RangeAggregation {
+//!                 field: "score".to_string(),
+//!                 ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
+//!             }),
+//!             sub_aggregation: sub_agg_req_1.clone(),
+//!         }),
+//!     ),
+//! ]
+//! .into_iter()
+//! .collect();
+//! ```
+//!
+//! # Distributed Aggregation
+//! When the data is distributed across different [crate::Index] instances, the
+//! [DistributedAggregationCollector] provides functionality to merge data between independent
+//! search calls by returning
+//! [IntermediateAggregationResults](intermediate_agg_result::IntermediateAggregationResults).
+//! IntermediateAggregationResults provides the
+//! [merge_fruits](intermediate_agg_result::IntermediateAggregationResults::merge_fruits) method to
+//! merge multiple results. The merged result can then be converted into
+//! [agg_result::AggregationResults] via the [Into] trait.
+
+pub mod agg_req;
+mod agg_req_with_accessor;
+pub mod agg_result;
+pub mod bucket;
+mod collector;
+pub mod intermediate_agg_result;
+pub mod metric;
+mod segment_agg_result;
+
+use std::collections::HashMap;
+use std::fmt::Display;
+
+pub use collector::{AggregationCollector, DistributedAggregationCollector};
+use itertools::Itertools;
+use serde::{Deserialize, Serialize};
+
+use crate::fastfield::FastValue;
+use crate::schema::Type;
+
+/// Represents an associative array `(key => value)` in a very efficient manner.
+#[derive(Clone, PartialEq, Serialize, Deserialize)]
+pub(crate) struct VecWithNames<T> {
+    values: Vec<T>,
+    keys: Vec<String>,
+}
+impl<T> Default for VecWithNames<T> {
+    fn default() -> Self {
+        Self {
+            values: Default::default(),
+            keys: Default::default(),
+        }
+    }
+}
+
+impl<T: std::fmt::Debug> std::fmt::Debug for VecWithNames<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_map().entries(self.iter()).finish()
+    }
+}
+
+impl<T> From<HashMap<String, T>> for VecWithNames<T> {
+    fn from(map: HashMap<String, T>) -> Self {
+        VecWithNames::from_entries(map.into_iter().collect_vec())
+    }
+}
+
+impl<T> VecWithNames<T> {
+    fn from_entries(mut entries: Vec<(String, T)>) -> Self {
+        // Sort to ensure the order of elements matches across multiple instances.
+        entries.sort_by(|left, right| left.0.cmp(&right.0));
+        let mut data = vec![];
+        let mut data_names = vec![];
+        for entry in entries {
+            data_names.push(entry.0);
+            data.push(entry.1);
+        }
+        VecWithNames {
+            values: data,
+            keys: data_names,
+        }
+    }
+    fn into_iter(self) -> impl Iterator<Item = (String, T)> {
+        self.keys.into_iter().zip(self.values.into_iter())
+    }
+    fn iter(&self) -> impl Iterator<Item = (&str, &T)> + '_ {
+        self.keys().zip(self.values.iter())
+    }
+    fn keys(&self) -> impl Iterator<Item = &str> + '_ {
+        self.keys.iter().map(|key| key.as_str())
+    }
+    fn values(&self) -> impl Iterator<Item = &T> + '_ {
+        self.values.iter()
+    }
+    fn values_mut(&mut self) -> impl Iterator<Item = &mut T> + '_ {
+        self.values.iter_mut()
+    }
+    fn entries(&self) -> impl Iterator<Item = (&str, &T)> + '_ {
+        self.keys().zip(self.values.iter())
+    }
+    fn is_empty(&self) -> bool {
+        self.keys.is_empty()
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash, Deserialize, Ord, PartialOrd)]
+/// The key to identify a bucket.
+pub enum Key {
+    /// String key
+    Str(String),
+    /// u64 key
+    U64(u64),
+    /// i64 key
+    I64(i64),
+}
+
+impl Display for Key {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Key::Str(val) => f.write_str(val),
+            Key::U64(val) => f.write_str(&val.to_string()),
+            Key::I64(val) => f.write_str(&val.to_string()),
+        }
+    }
+}
+
+impl Serialize for Key {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where S: serde::Serializer {
+        serializer.serialize_str(&self.to_string())
+    }
+}
+
+/// Inverse of `f64_to_fastfield_u64`.
+pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &Type) -> f64 {
+    match field_type {
+        Type::U64 => val as f64,
+        Type::I64 => i64::from_u64(val) as f64,
+        Type::F64 => f64::from_u64(val),
+        Type::Date | Type::Str | Type::Facet | Type::Bytes => unimplemented!(),
+    }
+}
+
+/// Converts the f64 value to the fast field value space.
+///
+/// If the fast field has u64 values, they are stored as u64 in the fast field.
+/// An f64 value of e.g. 2.0 therefore needs to be converted to 2u64.
+///
+/// If the fast field has f64 values, they are converted to u64 using a
+/// monotonic mapping before being stored.
+/// An f64 value of e.g. 2.0 needs to be converted using the same monotonic
+/// conversion function, so that the value matches the u64 value stored in the fast
+/// field.
+pub(crate) fn f64_to_fastfield_u64(val: f64, field_type: &Type) -> u64 {
+    match field_type {
+        Type::U64 => val as u64,
+        Type::I64 => (val as i64).to_u64(),
+        Type::F64 => val.to_u64(),
+        Type::Date | Type::Str | Type::Facet | Type::Bytes => unimplemented!(),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use futures::executor::block_on;
+    use serde_json::{json, Value};
+
+    use super::agg_req::{Aggregation, Aggregations, BucketAggregation};
+    use super::bucket::RangeAggregation;
+    use super::collector::AggregationCollector;
+    use super::metric::AverageAggregation;
+    use crate::aggregation::agg_req::{BucketAggregationType, MetricAggregation};
+    use crate::aggregation::agg_result::AggregationResults;
+    use crate::aggregation::segment_agg_result::DOC_BLOCK_SIZE;
+    use crate::aggregation::DistributedAggregationCollector;
+    use crate::query::TermQuery;
+    use crate::schema::{Cardinality, IndexRecordOption, Schema, TextFieldIndexing};
+    use crate::{Index, Term};
+
+    fn get_avg_req(field_name: &str) -> Aggregation {
+        Aggregation::Metric(MetricAggregation::Average(
+            AverageAggregation::from_field_name(field_name.to_string()),
+        ))
+    }
+
+    pub fn get_test_index_with_num_docs(
+        merge_segments: bool,
+        num_docs: usize,
+    ) -> crate::Result<Index> {
+        let mut schema_builder = Schema::builder();
+        let text_fieldtype = crate::schema::TextOptions::default()
+            .set_indexing_options(
+                TextFieldIndexing::default()
+                    .set_tokenizer("default")
+                    .set_index_option(IndexRecordOption::WithFreqs),
+            )
+            .set_stored();
+        let text_field = schema_builder.add_text_field("text", text_fieldtype);
+        let score_fieldtype =
+            crate::schema::IntOptions::default().set_fast(Cardinality::SingleValue);
+        let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
+        let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
+        let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
+        let fraction_field = schema_builder.add_f64_field(
+            "fraction_f64",
+            crate::schema::IntOptions::default().set_fast(Cardinality::SingleValue),
+        );
+        let index = Index::create_in_ram(schema_builder.build());
+        {
+            let mut index_writer = index.writer_for_tests()?;
+            for i in 0..num_docs {
+                // writing the segment
+                index_writer.add_document(doc!(
+                    text_field => "cool",
+                    score_field => i as u64,
+                    score_field_f64 => i as f64,
+                    score_field_i64 => i as i64,
+                    fraction_field => i as f64/100.0,
+                ))?;
+            }
+
+            index_writer.commit()?;
+        }
+        if merge_segments {
+            let segment_ids = index
+                .searchable_segment_ids()
+                .expect("Searchable segments failed.");
+            let mut index_writer = index.writer_for_tests()?;
+            block_on(index_writer.merge(&segment_ids))?;
+            index_writer.wait_merging_threads()?;
+        }
+
+        Ok(index)
+    }
+
+    // *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
+    fn test_aggregation_flushing(
+        merge_segments: bool,
+        use_distributed_collector: bool,
+    ) -> crate::Result<()> {
+        let index = get_test_index_with_num_docs(merge_segments, 300)?;
+
+        let reader = index.reader()?;
+        let text_field = reader.searcher().schema().get_field("text").unwrap();
+
+        let term_query = TermQuery::new(
+            Term::from_field_text(text_field, "cool"),
+            IndexRecordOption::Basic,
+        );
+
+        assert_eq!(DOC_BLOCK_SIZE, 256);
+        // In the tree we cache documents in blocks of DOC_BLOCK_SIZE before passing them down as
+        // one block.
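+        // Worked arithmetic: the range 0..266 holds 266 docs, i.e. one full block of
+        // 256 docs plus 10 residue docs that only the final forced flush passes down.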
+        //
+        // Build a request so that on the first level we have one full cache, which is then
+        // flushed. The same cache should have some residue docs at the end, which are flushed
+        // (range 0..266) -> 266 docs.
+        //
+        // The second level should also have some residue docs in the cache that are flushed at
+        // the end.
+        //
+        // A second bucket on the first level should have an unfilled cache.
+
+        let elasticsearch_compatible_json_req = r#"
+        {
+            "bucketsL1": {
+                "range": {
+                    "field": "score",
+                    "ranges": [ { "to": 3.0 }, { "from": 3.0, "to": 266.0 }, { "from": 266.0 } ]
+                },
+                "aggs": {
+                    "bucketsL2": {
+                        "range": {
+                            "field": "score",
+                            "ranges": [ { "to": 100.0 }, { "from": 100.0, "to": 266.0 }, { "from": 266.0 } ]
+                        }
+                    }
+                }
+            }
+        }
+        "#;
+
+        let agg_req: Aggregations =
+            serde_json::from_str(elasticsearch_compatible_json_req).unwrap();
+
+        let agg_res: AggregationResults = if use_distributed_collector {
+            let collector = DistributedAggregationCollector::from_aggs(agg_req);
+
+            let searcher = reader.searcher();
+            searcher.search(&term_query, &collector).unwrap().into()
+        } else {
+            let collector = AggregationCollector::from_aggs(agg_req);
+
+            let searcher = reader.searcher();
+            searcher.search(&term_query, &collector).unwrap()
+        };
+
+        let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
+
+        assert_eq!(res["bucketsL1"]["buckets"][0]["doc_count"], 3);
+        assert_eq!(
+            res["bucketsL1"]["buckets"][0]["bucketsL2"]["buckets"][0]["doc_count"],
+            3
+        );
+        assert_eq!(res["bucketsL1"]["buckets"][1]["key"], "3-266");
+        assert_eq!(res["bucketsL1"]["buckets"][1]["doc_count"], 266 - 3);
+        assert_eq!(
+            res["bucketsL1"]["buckets"][1]["bucketsL2"]["buckets"][0]["doc_count"],
+            97
+        );
+        assert_eq!(
+            res["bucketsL1"]["buckets"][1]["bucketsL2"]["buckets"][1]["doc_count"],
+            166
+        );
+        assert_eq!(
+            res["bucketsL1"]["buckets"][1]["bucketsL2"]["buckets"][2]["doc_count"],
+            0
+        );
+        assert_eq!(
+            res["bucketsL1"]["buckets"][2]["bucketsL2"]["buckets"][2]["doc_count"],
+            300 - 266
+        );
+        assert_eq!(res["bucketsL1"]["buckets"][2]["doc_count"], 300 - 266);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_aggregation_flushing_variants() {
+        test_aggregation_flushing(false, false).unwrap();
+        test_aggregation_flushing(false, true).unwrap();
+        test_aggregation_flushing(true, false).unwrap();
+        test_aggregation_flushing(true, true).unwrap();
+    }
+
+    pub fn get_test_index_2_segments(merge_segments: bool) -> crate::Result<Index> {
+        let mut schema_builder = Schema::builder();
+        let text_fieldtype = crate::schema::TextOptions::default()
+            .set_indexing_options(
+                TextFieldIndexing::default()
+                    .set_tokenizer("default")
+                    .set_index_option(IndexRecordOption::WithFreqs),
+            )
+            .set_stored();
+        let text_field = schema_builder.add_text_field("text", text_fieldtype);
+        let score_fieldtype =
+            crate::schema::IntOptions::default().set_fast(Cardinality::SingleValue);
+        let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
+        let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
+        let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
+        let index = Index::create_in_ram(schema_builder.build());
+        {
+            let mut index_writer = index.writer_for_tests()?;
+            // writing the segment
+            index_writer.add_document(doc!(
+                text_field => "cool",
+                score_field => 1u64,
+                score_field_f64 => 1f64,
+                score_field_i64 => 1i64,
+            ))?;
+            index_writer.add_document(doc!(
+                text_field => "cool",
+                score_field => 3u64,
+                score_field_f64 => 3f64,
+                score_field_i64 => 3i64,
+            ))?;
+            index_writer.add_document(doc!(
+                text_field => "cool",
+                score_field => 5u64,
+                score_field_f64 => 5f64,
+                score_field_i64 => 5i64,
+            ))?;
+            index_writer.add_document(doc!(
+                text_field => "nohit",
+                score_field => 6u64,
+                score_field_f64 => 6f64,
+                score_field_i64 => 6i64,
+            ))?;
+            index_writer.add_document(doc!(
+                text_field => "cool",
+                score_field => 7u64,
+                score_field_f64 => 7f64,
+                score_field_i64 => 7i64,
+            ))?;
+            index_writer.commit()?;
+            index_writer.add_document(doc!(
+                text_field => "cool",
+                score_field => 11u64,
+                score_field_f64 => 11f64,
+                score_field_i64 => 11i64,
+            ))?;
+            index_writer.add_document(doc!(
+                text_field => "cool",
+                score_field => 14u64,
+                score_field_f64 => 14f64,
+                score_field_i64 => 14i64,
+            ))?;
+
+            index_writer.add_document(doc!(
+                text_field => "cool",
+                score_field => 44u64,
+                score_field_f64 => 44.5f64,
+                score_field_i64 => 44i64,
+            ))?;
+
+            index_writer.commit()?;
+
+            // no hits segment
+            index_writer.add_document(doc!(
+                text_field => "nohit",
+                score_field => 44u64,
+                score_field_f64 => 44.5f64,
+                score_field_i64 => 44i64,
+            ))?;
+
+            index_writer.commit()?;
+        }
+        if merge_segments {
+            let segment_ids = index
+                .searchable_segment_ids()
+                .expect("Searchable segments failed.");
+            let mut index_writer = index.writer_for_tests()?;
+            block_on(index_writer.merge(&segment_ids))?;
+            index_writer.wait_merging_threads()?;
+        }
+
+        Ok(index)
+    }
+
+    #[test]
+    fn test_aggregation_level1() -> crate::Result<()> {
+        let index = get_test_index_2_segments(true)?;
+
+        let reader = index.reader()?;
+        let text_field = reader.searcher().schema().get_field("text").unwrap();
+
+        let term_query = TermQuery::new(
+            Term::from_field_text(text_field, "cool"),
+            IndexRecordOption::Basic,
+        );
+
+        let agg_req_1: Aggregations = vec![
+            ("average_i64".to_string(), get_avg_req("score_i64")),
+            ("average_f64".to_string(), get_avg_req("score_f64")),
+            ("average".to_string(), get_avg_req("score")),
+            (
+                "range".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Range(RangeAggregation {
+                        field: "score".to_string(),
+                        ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
+                    }),
+                    sub_aggregation: Default::default(),
+                }),
+            ),
+            (
+                "rangef64".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Range(RangeAggregation {
+                        field: "score_f64".to_string(),
+                        ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
+                    }),
+                    sub_aggregation: Default::default(),
+                }),
+            ),
+            (
+                "rangei64".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Range(RangeAggregation {
+                        field: "score_i64".to_string(),
+                        ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
+                    }),
+                    sub_aggregation: Default::default(),
+                }),
+            ),
+        ]
+        .into_iter()
+        .collect();
+
+        let collector = AggregationCollector::from_aggs(agg_req_1);
+
+        let searcher = reader.searcher();
+        let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
+
+        let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
+        assert_eq!(res["average"]["value"], 12.142857142857142);
+        assert_eq!(res["average_f64"]["value"], 12.214285714285714);
+        assert_eq!(res["average_i64"]["value"], 12.142857142857142);
+        assert_eq!(
+            res["range"]["buckets"],
+            json!([
+                {
+                    "key": "*-3",
+                    "doc_count": 1,
+                    "to": 3.0
+                },
+                {
+                    "key": "3-7",
+                    "doc_count": 2,
+                    "from": 3.0,
+                    "to": 7.0
+                },
+                {
+                    "key": "7-20",
+                    "doc_count": 3,
+ "from": 7.0, + "to": 20.0 + }, + { + "key": "20-*", + "doc_count": 1, + "from": 20.0 + } + ]) + ); + + Ok(()) + } + + fn test_aggregation_level2( + merge_segments: bool, + use_distributed_collector: bool, + use_elastic_json_req: bool, + ) -> crate::Result<()> { + let index = get_test_index_2_segments(merge_segments)?; + + let reader = index.reader()?; + let text_field = reader.searcher().schema().get_field("text").unwrap(); + + let term_query = TermQuery::new( + Term::from_field_text(text_field, "cool"), + IndexRecordOption::Basic, + ); + + let sub_agg_req: Aggregations = + vec![("average_in_range".to_string(), get_avg_req("score"))] + .into_iter() + .collect(); + let agg_req: Aggregations = if use_elastic_json_req { + let elasticsearch_compatible_json_req = r#" +{ + "rangef64": { + "range": { + "field": "score_f64", + "ranges": [ + { "to": 3.0 }, + { "from": 3.0, "to": 7.0 }, + { "from": 7.0, "to": 20.0 }, + { "from": 20.0 } + ] + }, + "aggs": { + "average_in_range": { "avg": { "field": "score" } } + } + }, + "rangei64": { + "range": { + "field": "score_i64", + "ranges": [ + { "to": 3.0 }, + { "from": 3.0, "to": 7.0 }, + { "from": 7.0, "to": 20.0 }, + { "from": 20.0 } + ] + }, + "aggs": { + "average_in_range": { "avg": { "field": "score" } } + } + }, + "average": { + "avg": { "field": "score" } + }, + "range": { + "range": { + "field": "score", + "ranges": [ + { "to": 3.0 }, + { "from": 3.0, "to": 7.0 }, + { "from": 7.0, "to": 20.0 }, + { "from": 20.0 } + ] + }, + "aggs": { + "average_in_range": { "avg": { "field": "score" } } + } + } +} +"#; + let value: Aggregations = + serde_json::from_str(elasticsearch_compatible_json_req).unwrap(); + value + } else { + let agg_req: Aggregations = vec![ + ("average".to_string(), get_avg_req("score")), + ( + "range".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Range(RangeAggregation { + field: "score".to_string(), + ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()], + }), + sub_aggregation: sub_agg_req.clone(), + }), + ), + ( + "rangef64".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Range(RangeAggregation { + field: "score_f64".to_string(), + ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()], + }), + sub_aggregation: sub_agg_req.clone(), + }), + ), + ( + "rangei64".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Range(RangeAggregation { + field: "score_i64".to_string(), + ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()], + }), + sub_aggregation: sub_agg_req, + }), + ), + ] + .into_iter() + .collect(); + agg_req + }; + + let agg_res: AggregationResults = if use_distributed_collector { + let collector = DistributedAggregationCollector::from_aggs(agg_req); + + let searcher = reader.searcher(); + searcher.search(&term_query, &collector).unwrap().into() + } else { + let collector = AggregationCollector::from_aggs(agg_req); + + let searcher = reader.searcher(); + searcher.search(&term_query, &collector).unwrap() + }; + + let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?; + + assert_eq!(res["range"]["buckets"][1]["key"], "3-7"); + assert_eq!(res["range"]["buckets"][1]["doc_count"], 2u64); + assert_eq!(res["rangef64"]["buckets"][1]["doc_count"], 2u64); + assert_eq!(res["rangei64"]["buckets"][1]["doc_count"], 2u64); + + assert_eq!(res["average"]["value"], 12.142857142857142f64); + assert_eq!(res["range"]["buckets"][2]["key"], "7-20"); + 
assert_eq!(res["range"]["buckets"][2]["doc_count"], 3u64); + assert_eq!(res["rangef64"]["buckets"][2]["doc_count"], 3u64); + assert_eq!(res["rangei64"]["buckets"][2]["doc_count"], 3u64); + assert_eq!(res["rangei64"]["buckets"][4], serde_json::Value::Null); + + assert_eq!(res["range"]["buckets"][3]["key"], "20-*"); + assert_eq!(res["range"]["buckets"][3]["doc_count"], 1u64); + assert_eq!(res["rangef64"]["buckets"][3]["doc_count"], 1u64); + assert_eq!(res["rangei64"]["buckets"][3]["doc_count"], 1u64); + + assert_eq!( + res["range"]["buckets"][3]["average_in_range"]["value"], + 44.0f64 + ); + assert_eq!( + res["rangef64"]["buckets"][3]["average_in_range"]["value"], + 44.0f64 + ); + assert_eq!( + res["rangei64"]["buckets"][3]["average_in_range"]["value"], + 44.0f64 + ); + + assert_eq!( + res["range"]["7-20"]["average_in_range"]["value"], + res["rangef64"]["7-20"]["average_in_range"]["value"] + ); + assert_eq!( + res["range"]["7-20"]["average_in_range"]["value"], + res["rangei64"]["7-20"]["average_in_range"]["value"] + ); + + Ok(()) + } + + #[test] + fn test_aggregation_level2_multi_segments() -> crate::Result<()> { + test_aggregation_level2(false, false, false) + } + + #[test] + fn test_aggregation_level2_single_segment() -> crate::Result<()> { + test_aggregation_level2(true, false, false) + } + + #[test] + fn test_aggregation_level2_multi_segments_distributed_collector() -> crate::Result<()> { + test_aggregation_level2(false, true, false) + } + + #[test] + fn test_aggregation_level2_single_segment_distributed_collector() -> crate::Result<()> { + test_aggregation_level2(true, true, false) + } + + #[test] + fn test_aggregation_level2_multi_segments_use_json() -> crate::Result<()> { + test_aggregation_level2(false, false, true) + } + + #[test] + fn test_aggregation_level2_single_segment_use_json() -> crate::Result<()> { + test_aggregation_level2(true, false, true) + } + + #[test] + fn test_aggregation_level2_multi_segments_distributed_collector_use_json() -> crate::Result<()> + { + test_aggregation_level2(false, true, true) + } + + #[test] + fn test_aggregation_level2_single_segment_distributed_collector_use_json() -> crate::Result<()> + { + test_aggregation_level2(true, true, true) + } + + #[test] + fn test_aggregation_invalid_requests() -> crate::Result<()> { + let index = get_test_index_2_segments(false)?; + + let reader = index.reader()?; + let text_field = reader.searcher().schema().get_field("text").unwrap(); + + let term_query = TermQuery::new( + Term::from_field_text(text_field, "cool"), + IndexRecordOption::Basic, + ); + + let agg_req_1: Aggregations = vec![( + "average".to_string(), + Aggregation::Metric(MetricAggregation::Average( + AverageAggregation::from_field_name("text".to_string()), + )), + )] + .into_iter() + .collect(); + + let collector = AggregationCollector::from_aggs(agg_req_1); + + let searcher = reader.searcher(); + let agg_res = searcher.search(&term_query, &collector).unwrap_err(); + + assert_eq!( + format!("{:?}", agg_res), + r#"InvalidArgument("Invalid field type in aggregation Str, only f64, u64, i64 is supported")"# + ); + Ok(()) + } + + #[cfg(all(test, feature = "unstable"))] + mod bench { + + use rand::{thread_rng, Rng}; + use test::{self, Bencher}; + + use super::*; + use crate::aggregation::metric::StatsAggregation; + use crate::query::AllQuery; + + fn get_test_index_bench(merge_segments: bool) -> crate::Result { + let mut schema_builder = Schema::builder(); + let text_fieldtype = crate::schema::TextOptions::default() + .set_indexing_options( + 
+                    TextFieldIndexing::default()
+                        .set_tokenizer("default")
+                        .set_index_option(IndexRecordOption::WithFreqs),
+                )
+                .set_stored();
+            let text_field = schema_builder.add_text_field("text", text_fieldtype);
+            let score_fieldtype =
+                crate::schema::IntOptions::default().set_fast(Cardinality::SingleValue);
+            let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
+            let score_field_f64 =
+                schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
+            let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
+            let index = Index::create_from_tempdir(schema_builder.build())?;
+            {
+                let mut rng = thread_rng();
+                let mut index_writer = index.writer_for_tests()?;
+                // writing the segment
+                for _ in 0..1_000_000 {
+                    let val: f64 = rng.gen_range(0.0..1_000_000.0);
+                    index_writer.add_document(doc!(
+                        text_field => "cool",
+                        score_field => val as u64,
+                        score_field_f64 => val as f64,
+                        score_field_i64 => val as i64,
+                    ))?;
+                }
+                index_writer.commit()?;
+            }
+            if merge_segments {
+                let segment_ids = index
+                    .searchable_segment_ids()
+                    .expect("Searchable segments failed.");
+                let mut index_writer = index.writer_for_tests()?;
+                block_on(index_writer.merge(&segment_ids))?;
+                index_writer.wait_merging_threads()?;
+            }
+
+            Ok(index)
+        }
+
+        #[bench]
+        fn bench_aggregation_average_u64(b: &mut Bencher) {
+            let index = get_test_index_bench(false).unwrap();
+            let reader = index.reader().unwrap();
+            let text_field = reader.searcher().schema().get_field("text").unwrap();
+
+            b.iter(|| {
+                let term_query = TermQuery::new(
+                    Term::from_field_text(text_field, "cool"),
+                    IndexRecordOption::Basic,
+                );
+
+                let agg_req_1: Aggregations = vec![(
+                    "average".to_string(),
+                    Aggregation::Metric(MetricAggregation::Average(
+                        AverageAggregation::from_field_name("score".to_string()),
+                    )),
+                )]
+                .into_iter()
+                .collect();
+
+                let collector = AggregationCollector::from_aggs(agg_req_1);
+
+                let searcher = reader.searcher();
+                let agg_res: AggregationResults =
+                    searcher.search(&term_query, &collector).unwrap().into();
+
+                agg_res
+            });
+        }
+
+        #[bench]
+        fn bench_aggregation_stats_f64(b: &mut Bencher) {
+            let index = get_test_index_bench(false).unwrap();
+            let reader = index.reader().unwrap();
+            let text_field = reader.searcher().schema().get_field("text").unwrap();
+
+            b.iter(|| {
+                let term_query = TermQuery::new(
+                    Term::from_field_text(text_field, "cool"),
+                    IndexRecordOption::Basic,
+                );
+
+                let agg_req_1: Aggregations = vec![(
+                    "average_f64".to_string(),
+                    Aggregation::Metric(MetricAggregation::Stats(
+                        StatsAggregation::from_field_name("score_f64".to_string()),
+                    )),
+                )]
+                .into_iter()
+                .collect();
+
+                let collector = AggregationCollector::from_aggs(agg_req_1);
+
+                let searcher = reader.searcher();
+                let agg_res: AggregationResults =
+                    searcher.search(&term_query, &collector).unwrap().into();
+
+                agg_res
+            });
+        }
+
+        #[bench]
+        fn bench_aggregation_average_f64(b: &mut Bencher) {
+            let index = get_test_index_bench(false).unwrap();
+            let reader = index.reader().unwrap();
+            let text_field = reader.searcher().schema().get_field("text").unwrap();
+
+            b.iter(|| {
+                let term_query = TermQuery::new(
+                    Term::from_field_text(text_field, "cool"),
+                    IndexRecordOption::Basic,
+                );
+
+                let agg_req_1: Aggregations = vec![(
+                    "average_f64".to_string(),
+                    Aggregation::Metric(MetricAggregation::Average(
+                        AverageAggregation::from_field_name("score_f64".to_string()),
+                    )),
+                )]
+                .into_iter()
+                .collect();
+
+                let collector = AggregationCollector::from_aggs(agg_req_1);
+
+                let searcher = reader.searcher();
+                let agg_res: AggregationResults =
+                    searcher.search(&term_query, &collector).unwrap().into();
+
+                agg_res
+            });
+        }
+
+        #[bench]
+        fn bench_aggregation_average_u64_and_f64(b: &mut Bencher) {
+            let index = get_test_index_bench(false).unwrap();
+            let reader = index.reader().unwrap();
+            let text_field = reader.searcher().schema().get_field("text").unwrap();
+
+            b.iter(|| {
+                let term_query = TermQuery::new(
+                    Term::from_field_text(text_field, "cool"),
+                    IndexRecordOption::Basic,
+                );
+
+                let agg_req_1: Aggregations = vec![
+                    (
+                        "average_f64".to_string(),
+                        Aggregation::Metric(MetricAggregation::Average(
+                            AverageAggregation::from_field_name("score_f64".to_string()),
+                        )),
+                    ),
+                    (
+                        "average".to_string(),
+                        Aggregation::Metric(MetricAggregation::Average(
+                            AverageAggregation::from_field_name("score".to_string()),
+                        )),
+                    ),
+                ]
+                .into_iter()
+                .collect();
+
+                let collector = AggregationCollector::from_aggs(agg_req_1);
+
+                let searcher = reader.searcher();
+                let agg_res: AggregationResults =
+                    searcher.search(&term_query, &collector).unwrap().into();
+
+                agg_res
+            });
+        }
+
+        #[bench]
+        fn bench_aggregation_range_only(b: &mut Bencher) {
+            let index = get_test_index_bench(false).unwrap();
+            let reader = index.reader().unwrap();
+
+            b.iter(|| {
+                let agg_req_1: Aggregations = vec![(
+                    "rangef64".to_string(),
+                    Aggregation::Bucket(BucketAggregation {
+                        bucket_agg: BucketAggregationType::Range(RangeAggregation {
+                            field: "score_f64".to_string(),
+                            ranges: vec![
+                                (3f64..7000f64).into(),
+                                (7000f64..20000f64).into(),
+                                (20000f64..30000f64).into(),
+                                (30000f64..40000f64).into(),
+                                (40000f64..50000f64).into(),
+                                (50000f64..60000f64).into(),
+                            ],
+                        }),
+                        sub_aggregation: Default::default(),
+                    }),
+                )]
+                .into_iter()
+                .collect();
+
+                let collector = AggregationCollector::from_aggs(agg_req_1);
+
+                let searcher = reader.searcher();
+                let agg_res: AggregationResults =
+                    searcher.search(&AllQuery, &collector).unwrap().into();
+
+                agg_res
+            });
+        }
+
+        #[bench]
+        fn bench_aggregation_sub_tree(b: &mut Bencher) {
+            let index = get_test_index_bench(false).unwrap();
+            let reader = index.reader().unwrap();
+            let text_field = reader.searcher().schema().get_field("text").unwrap();
+
+            b.iter(|| {
+                let term_query = TermQuery::new(
+                    Term::from_field_text(text_field, "cool"),
+                    IndexRecordOption::Basic,
+                );
+
+                let sub_agg_req_1: Aggregations = vec![(
+                    "average_in_range".to_string(),
+                    Aggregation::Metric(MetricAggregation::Average(
+                        AverageAggregation::from_field_name("score".to_string()),
+                    )),
+                )]
+                .into_iter()
+                .collect();
+
+                let agg_req_1: Aggregations = vec![
+                    (
+                        "average".to_string(),
+                        Aggregation::Metric(MetricAggregation::Average(
+                            AverageAggregation::from_field_name("score".to_string()),
+                        )),
+                    ),
+                    (
+                        "rangef64".to_string(),
+                        Aggregation::Bucket(BucketAggregation {
+                            bucket_agg: BucketAggregationType::Range(RangeAggregation {
+                                field: "score_f64".to_string(),
+                                ranges: vec![
+                                    (3f64..7000f64).into(),
+                                    (7000f64..20000f64).into(),
+                                    (20000f64..60000f64).into(),
+                                ],
+                            }),
+                            sub_aggregation: sub_agg_req_1.clone(),
+                        }),
+                    ),
+                ]
+                .into_iter()
+                .collect();
+
+                let collector = AggregationCollector::from_aggs(agg_req_1);
+
+                let searcher = reader.searcher();
+                let agg_res: AggregationResults =
+                    searcher.search(&term_query, &collector).unwrap().into();
+
+                agg_res
+            });
+        }
+    }
+}
diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs
new file mode 100644
index 000000000..38ea42d24
--- /dev/null
+++ b/src/aggregation/segment_agg_result.rs
@@ -0,0 +1,195 @@
+//! Contains the aggregation tree that is used during collection in a segment.
+//! This tree contains data structures optimized for fast collection.
+//! The tree can be converted to an intermediate tree, which contains data structures
+//! optimized for merging.
+
+use std::fmt::Debug;
+
+use itertools::Itertools;
+
+use super::agg_req::MetricAggregation;
+use super::agg_req_with_accessor::{
+    AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor,
+};
+use super::bucket::SegmentRangeCollector;
+use super::metric::{
+    AverageAggregation, SegmentAverageCollector, SegmentStatsCollector, StatsAggregation,
+};
+use super::{Key, VecWithNames};
+use crate::aggregation::agg_req::BucketAggregationType;
+use crate::DocId;
+
+pub(crate) const DOC_BLOCK_SIZE: usize = 256;
+pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
+
+#[derive(Clone, PartialEq)]
+pub(crate) struct SegmentAggregationResultsCollector {
+    pub(crate) metrics: VecWithNames<SegmentMetricResultCollector>,
+    pub(crate) buckets: VecWithNames<SegmentBucketResultCollector>,
+    staged_docs: DocBlock,
+    num_staged_docs: usize,
+}
+
+impl Debug for SegmentAggregationResultsCollector {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("SegmentAggregationResultsCollector")
+            .field("metrics", &self.metrics)
+            .field("buckets", &self.buckets)
+            .field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
+            .field("num_staged_docs", &self.num_staged_docs)
+            .finish()
+    }
+}
+
+impl SegmentAggregationResultsCollector {
+    pub(crate) fn from_req(req: &AggregationsWithAccessor) -> crate::Result<Self> {
+        let buckets = req
+            .buckets
+            .entries()
+            .map(|(key, req)| {
+                Ok((
+                    key.to_string(),
+                    SegmentBucketResultCollector::from_req(req)?,
+                ))
+            })
+            .collect::<crate::Result<Vec<_>>>()?;
+        let metrics = req
+            .metrics
+            .entries()
+            .map(|(key, req)| (key.to_string(), SegmentMetricResultCollector::from_req(req)))
+            .collect_vec();
+        Ok(SegmentAggregationResultsCollector {
+            metrics: VecWithNames::from_entries(metrics),
+            buckets: VecWithNames::from_entries(buckets),
+            staged_docs: [0; DOC_BLOCK_SIZE],
+            num_staged_docs: 0,
+        })
+    }
+
+    #[inline]
+    pub(crate) fn collect(
+        &mut self,
+        doc: crate::DocId,
+        agg_with_accessor: &AggregationsWithAccessor,
+    ) {
+        self.staged_docs[self.num_staged_docs] = doc;
+        self.num_staged_docs += 1;
+        if self.num_staged_docs == self.staged_docs.len() {
+            self.flush_staged_docs(agg_with_accessor, false);
+        }
+    }
+
+    #[inline(never)]
+    pub(crate) fn flush_staged_docs(
+        &mut self,
+        agg_with_accessor: &AggregationsWithAccessor,
+        force_flush: bool,
+    ) {
+        for (agg_with_accessor, collector) in agg_with_accessor
+            .metrics
+            .values()
+            .zip(self.metrics.values_mut())
+        {
+            collector.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor);
+        }
+        for (agg_with_accessor, collector) in agg_with_accessor
+            .buckets
+            .values()
+            .zip(self.buckets.values_mut())
+        {
+            collector.collect_block(
+                &self.staged_docs[..self.num_staged_docs],
+                agg_with_accessor,
+                force_flush,
+            );
+        }
+
+        self.num_staged_docs = 0;
+    }
+}
+
+#[derive(Clone, Debug, PartialEq)]
+pub(crate) enum SegmentMetricResultCollector {
+    Average(SegmentAverageCollector),
+    Stats(SegmentStatsCollector),
+}
+
+impl SegmentMetricResultCollector {
+    pub fn from_req(req: &MetricAggregationWithAccessor) -> Self {
+        match &req.metric {
+            MetricAggregation::Average(AverageAggregation { field: _ }) => {
+                SegmentMetricResultCollector::Average(SegmentAverageCollector::from_req(
+                    req.field_type,
+                ))
+            }
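+            // Stats follows the same pattern as Average: only the fast field type is
+            // captured up front; the field accessor itself is handed in per doc block
+            // via collect_block below.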
+            MetricAggregation::Stats(StatsAggregation { field: _ }) => {
+                SegmentMetricResultCollector::Stats(SegmentStatsCollector::from_req(req.field_type))
+            }
+        }
+    }
+    pub(crate) fn collect_block(&mut self, doc: &[DocId], metric: &MetricAggregationWithAccessor) {
+        match self {
+            SegmentMetricResultCollector::Average(avg_collector) => {
+                avg_collector.collect_block(doc, &metric.accessor);
+            }
+            SegmentMetricResultCollector::Stats(stats_collector) => {
+                stats_collector.collect_block(doc, &metric.accessor);
+            }
+        }
+    }
+}
+
+/// SegmentBucketResultCollector carries specialized bucket collectors for collection inside
+/// segments.
+/// The typical map structure is not suitable during collection, for performance
+/// reasons.
+#[derive(Clone, Debug, PartialEq)]
+pub(crate) enum SegmentBucketResultCollector {
+    Range(SegmentRangeCollector),
+}
+
+impl SegmentBucketResultCollector {
+    pub fn from_req(req: &BucketAggregationWithAccessor) -> crate::Result<Self> {
+        match &req.bucket_agg {
+            BucketAggregationType::Range(range_req) => Ok(Self::Range(
+                SegmentRangeCollector::from_req(range_req, &req.sub_aggregation, req.field_type)?,
+            )),
+        }
+    }
+
+    #[inline]
+    pub(crate) fn collect_block(
+        &mut self,
+        doc: &[DocId],
+        bucket_with_accessor: &BucketAggregationWithAccessor,
+        force_flush: bool,
+    ) {
+        match self {
+            SegmentBucketResultCollector::Range(range) => {
+                range.collect_block(doc, bucket_with_accessor, force_flush);
+            }
+        }
+    }
+}
+
+#[derive(Clone, PartialEq)]
+pub(crate) struct SegmentRangeBucketEntry {
+    pub key: Key,
+    pub doc_count: u64,
+    pub sub_aggregation: Option<SegmentAggregationResultsCollector>,
+    /// The from range of the bucket. Equals f64::MIN when None.
+    pub from: Option<f64>,
+    /// The to range of the bucket. Equals f64::MAX when None.
+    pub to: Option<f64>,
+}
+
+impl Debug for SegmentRangeBucketEntry {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("SegmentRangeBucketEntry")
+            .field("key", &self.key)
+            .field("doc_count", &self.doc_count)
+            .field("from", &self.from)
+            .field("to", &self.to)
+            .finish()
+    }
+}
diff --git a/src/collector/histogram_collector.rs b/src/collector/histogram_collector.rs
index 8685b4aca..be5363523 100644
--- a/src/collector/histogram_collector.rs
+++ b/src/collector/histogram_collector.rs
@@ -19,7 +19,7 @@ use crate::{DocId, Score};
 ///
 /// # Warning
 ///
-/// f64 field. are not supported.
+/// f64 fields are not supported.
 #[derive(Clone)]
 pub struct HistogramCollector {
     min_value: u64,
diff --git a/src/error.rs b/src/error.rs
index 146112bd2..ba8f520cb 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -74,6 +74,9 @@ pub enum TantivyError {
     /// A thread holding the locked panicked and poisoned the lock.
     #[error("A thread holding the locked panicked and poisoned the lock")]
     Poisoned,
+    /// The provided field name does not exist.
+    #[error("The field does not exist: '{0}'")]
+    FieldNotFound(String),
     /// Invalid argument was passed by the user.
#[error("An invalid argument was passed: '{0}'")] InvalidArgument(String), diff --git a/src/fastfield/bytes/mod.rs b/src/fastfield/bytes/mod.rs index a9bad4c7c..480baa536 100644 --- a/src/fastfield/bytes/mod.rs +++ b/src/fastfield/bytes/mod.rs @@ -86,7 +86,7 @@ mod tests { let field = searcher.schema().get_field("string_bytes").unwrap(); let term = Term::from_field_bytes(field, b"lucene".as_ref()); let term_query = TermQuery::new(term, IndexRecordOption::Basic); - let term_weight = term_query.specialized_weight(&searcher, true)?; + let term_weight = term_query.specialized_weight(&*searcher, true)?; let term_scorer = term_weight.specialized_scorer(searcher.segment_reader(0), 1.0)?; assert_eq!(term_scorer.doc(), 0u32); Ok(()) @@ -99,7 +99,7 @@ mod tests { let field = searcher.schema().get_field("string_bytes").unwrap(); let term = Term::from_field_bytes(field, b"lucene".as_ref()); let term_query = TermQuery::new(term, IndexRecordOption::Basic); - let term_weight_err = term_query.specialized_weight(&searcher, false); + let term_weight_err = term_query.specialized_weight(&*searcher, false); assert!(matches!( term_weight_err, Err(crate::TantivyError::SchemaError(_)) diff --git a/src/fastfield/reader.rs b/src/fastfield/reader.rs index e18245bfe..eeb6b3d9b 100644 --- a/src/fastfield/reader.rs +++ b/src/fastfield/reader.rs @@ -112,6 +112,7 @@ impl DynamicFastFieldReader { } impl FastFieldReader for DynamicFastFieldReader { + #[inline] fn get(&self, doc: DocId) -> Item { match self { Self::Bitpacked(reader) => reader.get(doc), @@ -119,6 +120,7 @@ impl FastFieldReader for DynamicFastFieldReader { Self::MultiLinearInterpol(reader) => reader.get(doc), } } + #[inline] fn get_range(&self, start: u64, output: &mut [Item]) { match self { Self::Bitpacked(reader) => reader.get_range(start, output), @@ -174,6 +176,7 @@ impl FastFieldReaderCodecWrapper Item { Item::from_u64(self.reader.get_u64(doc, self.bytes.as_slice())) } diff --git a/src/lib.rs b/src/lib.rs index 5cc600833..fc7dacc0d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -144,6 +144,7 @@ mod indexer; pub mod error; pub mod tokenizer; +pub mod aggregation; pub mod collector; pub mod directory; pub mod fastfield;