diff --git a/CHANGELOG.md b/CHANGELOG.md index 39e19f3e6..00d88d293 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,8 +8,9 @@ Unreleased - Converting a `time::OffsetDateTime` to `Value::Date` implicitly converts the value into UTC. If this is not desired do the time zone conversion yourself and use `time::PrimitiveDateTime` directly instead. -- Add [histogram](https://github.com/quickwit-oss/tantivy/pull/1306) aggregation (@PSeitz). -- Add support for fastfield on text fields (@PSeitz). +- Add [histogram](https://github.com/quickwit-oss/tantivy/pull/1306) aggregation (@PSeitz) +- Add support for fastfield on text fields (@PSeitz) +- Add terms aggregation (@PSeitz) Tantivy 0.17 ================================ diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index a8f8f059c..889abbbb3 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -48,8 +48,8 @@ use std::collections::{HashMap, HashSet}; use serde::{Deserialize, Serialize}; -use super::bucket::HistogramAggregation; pub use super::bucket::RangeAggregation; +use super::bucket::{HistogramAggregation, TermsAggregation}; use super::metric::{AverageAggregation, StatsAggregation}; use super::VecWithNames; @@ -102,8 +102,14 @@ pub(crate) struct BucketAggregationInternal { impl BucketAggregationInternal { pub(crate) fn as_histogram(&self) -> &HistogramAggregation { match &self.bucket_agg { - BucketAggregationType::Range(_) => panic!("unexpected aggregation"), BucketAggregationType::Histogram(histogram) => histogram, + _ => panic!("unexpected aggregation"), + } + } + pub(crate) fn as_term(&self) -> &TermsAggregation { + match &self.bucket_agg { + BucketAggregationType::Terms(terms) => terms, + _ => panic!("unexpected aggregation"), } } } @@ -177,11 +183,15 @@ pub enum BucketAggregationType { /// Put data into buckets of user-defined ranges. #[serde(rename = "histogram")] Histogram(HistogramAggregation), + /// Put data into buckets of terms. + #[serde(rename = "terms")] + Terms(TermsAggregation), } impl BucketAggregationType { fn get_fast_field_names(&self, fast_field_names: &mut HashSet) { match self { + BucketAggregationType::Terms(terms) => fast_field_names.insert(terms.field.to_string()), BucketAggregationType::Range(range) => fast_field_names.insert(range.field.to_string()), BucketAggregationType::Histogram(histogram) => { fast_field_names.insert(histogram.field.to_string()) diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index bf87e5100..4866621ad 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -1,12 +1,16 @@ //! This will enhance the request tree with access to the fastfield and metadata. 
+use std::sync::Arc; + use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation}; -use super::bucket::{HistogramAggregation, RangeAggregation}; +use super::bucket::{HistogramAggregation, RangeAggregation, TermsAggregation}; use super::metric::{AverageAggregation, StatsAggregation}; use super::VecWithNames; -use crate::fastfield::{type_and_cardinality, DynamicFastFieldReader, FastType}; +use crate::fastfield::{ + type_and_cardinality, DynamicFastFieldReader, FastType, MultiValuedFastFieldReader, +}; use crate::schema::{Cardinality, Type}; -use crate::{SegmentReader, TantivyError}; +use crate::{InvertedIndexReader, SegmentReader, TantivyError}; #[derive(Clone, Default)] pub(crate) struct AggregationsWithAccessor { @@ -27,11 +31,32 @@ impl AggregationsWithAccessor { } } +#[derive(Clone)] +pub(crate) enum FastFieldAccessor { + Multi(MultiValuedFastFieldReader), + Single(DynamicFastFieldReader), +} +impl FastFieldAccessor { + pub fn as_single(&self) -> &DynamicFastFieldReader { + match self { + FastFieldAccessor::Multi(_) => panic!("unexpected ff cardinality"), + FastFieldAccessor::Single(reader) => reader, + } + } + pub fn as_multi(&self) -> &MultiValuedFastFieldReader { + match self { + FastFieldAccessor::Multi(reader) => reader, + FastFieldAccessor::Single(_) => panic!("unexpected ff cardinality"), + } + } +} + #[derive(Clone)] pub struct BucketAggregationWithAccessor { /// In general there can be buckets without fast field access, e.g. buckets that are created /// based on search terms. So eventually this needs to be Option or moved. - pub(crate) accessor: DynamicFastFieldReader, + pub(crate) accessor: FastFieldAccessor, + pub(crate) inverted_index: Option>, pub(crate) field_type: Type, pub(crate) bucket_agg: BucketAggregationType, pub(crate) sub_aggregation: AggregationsWithAccessor, @@ -43,14 +68,25 @@ impl BucketAggregationWithAccessor { sub_aggregation: &Aggregations, reader: &SegmentReader, ) -> crate::Result { + let mut inverted_index = None; let (accessor, field_type) = match &bucket { BucketAggregationType::Range(RangeAggregation { field: field_name, ranges: _, - }) => get_ff_reader_and_validate(reader, field_name)?, + }) => get_ff_reader_and_validate(reader, field_name, false)?, BucketAggregationType::Histogram(HistogramAggregation { field: field_name, .. - }) => get_ff_reader_and_validate(reader, field_name)?, + }) => get_ff_reader_and_validate(reader, field_name, false)?, + BucketAggregationType::Terms(TermsAggregation { + field: field_name, .. + }) => { + let field = reader + .schema() + .get_field(field_name) + .ok_or_else(|| TantivyError::FieldNotFound(field_name.to_string()))?; + inverted_index = Some(reader.inverted_index(field)?); + get_ff_reader_and_validate(reader, field_name, true)? 
+ } }; let sub_aggregation = sub_aggregation.clone(); Ok(BucketAggregationWithAccessor { @@ -58,6 +94,7 @@ impl BucketAggregationWithAccessor { field_type, sub_aggregation: get_aggs_with_accessor_and_validate(&sub_aggregation, reader)?, bucket_agg: bucket.clone(), + inverted_index, }) } } @@ -78,10 +115,10 @@ impl MetricAggregationWithAccessor { match &metric { MetricAggregation::Average(AverageAggregation { field: field_name }) | MetricAggregation::Stats(StatsAggregation { field: field_name }) => { - let (accessor, field_type) = get_ff_reader_and_validate(reader, field_name)?; + let (accessor, field_type) = get_ff_reader_and_validate(reader, field_name, false)?; Ok(MetricAggregationWithAccessor { - accessor, + accessor: accessor.as_single().clone(), field_type, metric: metric.clone(), }) @@ -121,7 +158,8 @@ pub(crate) fn get_aggs_with_accessor_and_validate( fn get_ff_reader_and_validate( reader: &SegmentReader, field_name: &str, -) -> crate::Result<(DynamicFastFieldReader, Type)> { + multi: bool, +) -> crate::Result<(FastFieldAccessor, Type)> { let field = reader .schema() .get_field(field_name) @@ -129,7 +167,7 @@ fn get_ff_reader_and_validate( let field_type = reader.schema().get_field_entry(field).field_type(); if let Some((ff_type, cardinality)) = type_and_cardinality(field_type) { - if cardinality == Cardinality::MultiValues || ff_type == FastType::Date { + if (!multi && cardinality == Cardinality::MultiValues) || ff_type == FastType::Date { return Err(TantivyError::InvalidArgument(format!( "Invalid field type in aggregation {:?}, only Cardinality::SingleValue supported", field_type.value_type() @@ -137,13 +175,19 @@ fn get_ff_reader_and_validate( } } else { return Err(TantivyError::InvalidArgument(format!( - "Only single value fast fields of type f64, u64, i64 are supported, but got {:?} ", + "Only fast fields of type f64, u64, i64 are supported, but got {:?} ", field_type.value_type() ))); }; let ff_fields = reader.fast_fields(); - ff_fields - .u64_lenient(field) - .map(|field| (field, field_type.value_type())) + if multi { + ff_fields + .u64s_lenient(field) + .map(|field| (FastFieldAccessor::Multi(field), field_type.value_type())) + } else { + ff_fields + .u64_lenient(field) + .map(|field| (FastFieldAccessor::Single(field), field_type.value_type())) + } } diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index 6132ba7cb..0cd5ec915 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -11,7 +11,7 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; use super::agg_req::{Aggregations, AggregationsInternal, BucketAggregationInternal}; -use super::bucket::intermediate_buckets_to_final_buckets; +use super::bucket::{intermediate_buckets_to_final_buckets, GetDocCount}; use super::intermediate_agg_result::{ IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, IntermediateMetricResult, IntermediateRangeBucketEntry, @@ -34,8 +34,8 @@ impl AggregationResults { /// Convert and intermediate result and its aggregation request to the final result /// /// Internal function, CollectorAggregations is used instead Aggregations, which is optimized - /// for internal processing - fn from_intermediate_and_req_internal( + /// for internal processing, by splitting metric and buckets into seperate groups. 
+ pub(crate) fn from_intermediate_and_req_internal( results: IntermediateAggregationResults, req: &AggregationsInternal, ) -> Self { @@ -140,6 +140,18 @@ pub enum BucketResult { /// See [HistogramAggregation](super::bucket::HistogramAggregation) buckets: Vec, }, + /// This is the term result + Terms { + /// The buckets. + /// + /// See [TermsAggregation](super::bucket::TermsAggregation) + buckets: Vec, + /// The number of documents that didn’t make it into to TOP N due to shard_size or size + sum_other_doc_count: u64, + #[serde(skip_serializing_if = "Option::is_none")] + /// The upper bound error for the doc count of each term. + doc_count_error_upper_bound: Option, + }, } impl BucketResult { @@ -173,6 +185,9 @@ impl BucketResult { BucketResult::Histogram { buckets } } + IntermediateBucketResult::Terms(terms) => { + terms.into_final_result(req.as_term(), &req.sub_aggregation) + } } } } @@ -229,6 +244,11 @@ impl BucketEntry { } } } +impl GetDocCount for BucketEntry { + fn doc_count(&self) -> u64 { + self.doc_count + } +} /// This is the range entry for a bucket, which contains a key, count, and optionally /// sub_aggregations. diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 4bac22505..7f57ed178 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -13,9 +13,7 @@ use crate::aggregation::f64_from_fastfield_u64; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, }; -use crate::aggregation::segment_agg_result::{ - SegmentAggregationResultsCollector, SegmentHistogramBucketEntry, -}; +use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector; use crate::fastfield::{DynamicFastFieldReader, FastFieldReader}; use crate::schema::Type; use crate::{DocId, TantivyError}; @@ -159,6 +157,27 @@ impl HistogramBounds { } } +#[derive(Clone, Debug, PartialEq)] +pub(crate) struct SegmentHistogramBucketEntry { + pub key: f64, + pub doc_count: u64, +} + +impl SegmentHistogramBucketEntry { + pub(crate) fn into_intermediate_bucket_entry( + self, + sub_aggregation: SegmentAggregationResultsCollector, + agg_with_accessor: &AggregationsWithAccessor, + ) -> IntermediateHistogramBucketEntry { + IntermediateHistogramBucketEntry { + key: self.key, + doc_count: self.doc_count, + sub_aggregation: sub_aggregation + .into_intermediate_aggregations_result(agg_with_accessor), + } + } +} + /// The collector puts values from the fast field into the correct buckets and does a conversion to /// the correct datatype. 
#[derive(Clone, Debug, PartialEq)] @@ -174,7 +193,10 @@ pub struct SegmentHistogramCollector { } impl SegmentHistogramCollector { - pub fn into_intermediate_bucket_result(self) -> IntermediateBucketResult { + pub fn into_intermediate_bucket_result( + self, + agg_with_accessor: &BucketAggregationWithAccessor, + ) -> IntermediateBucketResult { let mut buckets = Vec::with_capacity( self.buckets .iter() @@ -193,7 +215,12 @@ impl SegmentHistogramCollector { .into_iter() .zip(sub_aggregations.into_iter()) .filter(|(bucket, _sub_aggregation)| bucket.doc_count != 0) - .map(|(bucket, sub_aggregation)| (bucket, sub_aggregation).into()), + .map(|(bucket, sub_aggregation)| { + bucket.into_intermediate_bucket_entry( + sub_aggregation, + &agg_with_accessor.sub_aggregation, + ) + }), ) } else { buckets.extend( @@ -273,12 +300,13 @@ impl SegmentHistogramCollector { let get_bucket_num = |val| (get_bucket_num_f64(val, interval, offset) as i64 - first_bucket_num) as usize; + let accessor = bucket_with_accessor.accessor.as_single(); let mut iter = doc.chunks_exact(4); for docs in iter.by_ref() { - let val0 = self.f64_from_fastfield_u64(bucket_with_accessor.accessor.get(docs[0])); - let val1 = self.f64_from_fastfield_u64(bucket_with_accessor.accessor.get(docs[1])); - let val2 = self.f64_from_fastfield_u64(bucket_with_accessor.accessor.get(docs[2])); - let val3 = self.f64_from_fastfield_u64(bucket_with_accessor.accessor.get(docs[3])); + let val0 = self.f64_from_fastfield_u64(accessor.get(docs[0])); + let val1 = self.f64_from_fastfield_u64(accessor.get(docs[1])); + let val2 = self.f64_from_fastfield_u64(accessor.get(docs[2])); + let val3 = self.f64_from_fastfield_u64(accessor.get(docs[3])); let bucket_pos0 = get_bucket_num(val0); let bucket_pos1 = get_bucket_num(val1); @@ -315,8 +343,7 @@ impl SegmentHistogramCollector { ); } for doc in iter.remainder() { - let val = - f64_from_fastfield_u64(bucket_with_accessor.accessor.get(*doc), &self.field_type); + let val = f64_from_fastfield_u64(accessor.get(*doc), &self.field_type); if !bounds.contains(val) { continue; } @@ -630,41 +657,9 @@ mod tests { }; use crate::aggregation::metric::{AverageAggregation, StatsAggregation}; use crate::aggregation::tests::{ - get_test_index_2_segments, get_test_index_from_values, get_test_index_with_num_docs, + exec_request, exec_request_with_query, get_test_index_2_segments, + get_test_index_from_values, get_test_index_with_num_docs, }; - use crate::aggregation::AggregationCollector; - use crate::query::{AllQuery, TermQuery}; - use crate::schema::IndexRecordOption; - use crate::{Index, Term}; - - fn exec_request(agg_req: Aggregations, index: &Index) -> crate::Result { - exec_request_with_query(agg_req, index, None) - } - fn exec_request_with_query( - agg_req: Aggregations, - index: &Index, - query: Option<(&str, &str)>, - ) -> crate::Result { - let collector = AggregationCollector::from_aggs(agg_req); - - let reader = index.reader()?; - let searcher = reader.searcher(); - let agg_res = if let Some((field, term)) = query { - let text_field = reader.searcher().schema().get_field(field).unwrap(); - - let term_query = TermQuery::new( - Term::from_field_text(text_field, term), - IndexRecordOption::Basic, - ); - - searcher.search(&term_query, &collector)? - } else { - searcher.search(&AllQuery, &collector)? 
- }; - - let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?; - Ok(res) - } #[test] fn histogram_test_crooked_values() -> crate::Result<()> { diff --git a/src/aggregation/bucket/mod.rs b/src/aggregation/bucket/mod.rs index 0a9991fce..dd8a7f270 100644 --- a/src/aggregation/bucket/mod.rs +++ b/src/aggregation/bucket/mod.rs @@ -9,8 +9,10 @@ mod histogram; mod range; +mod term_agg; pub(crate) use histogram::SegmentHistogramCollector; pub use histogram::*; pub(crate) use range::SegmentRangeCollector; pub use range::*; +pub use term_agg::*; diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 2715993b3..47be0ee0f 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -1,3 +1,4 @@ +use std::fmt::Debug; use std::ops::Range; use serde::{Deserialize, Serialize}; @@ -5,10 +6,10 @@ use serde::{Deserialize, Serialize}; use crate::aggregation::agg_req_with_accessor::{ AggregationsWithAccessor, BucketAggregationWithAccessor, }; -use crate::aggregation::intermediate_agg_result::IntermediateBucketResult; -use crate::aggregation::segment_agg_result::{ - SegmentAggregationResultsCollector, SegmentRangeBucketEntry, +use crate::aggregation::intermediate_agg_result::{ + IntermediateBucketResult, IntermediateRangeBucketEntry, }; +use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector; use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key}; use crate::fastfield::FastFieldReader; use crate::schema::Type; @@ -102,8 +103,53 @@ pub struct SegmentRangeCollector { field_type: Type, } +#[derive(Clone, PartialEq)] +pub(crate) struct SegmentRangeBucketEntry { + pub key: Key, + pub doc_count: u64, + pub sub_aggregation: Option, + /// The from range of the bucket. Equals f64::MIN when None. + pub from: Option, + /// The to range of the bucket. Equals f64::MAX when None. 
+ pub to: Option, +} + +impl Debug for SegmentRangeBucketEntry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SegmentRangeBucketEntry") + .field("key", &self.key) + .field("doc_count", &self.doc_count) + .field("from", &self.from) + .field("to", &self.to) + .finish() + } +} +impl SegmentRangeBucketEntry { + pub(crate) fn into_intermediate_bucket_entry( + self, + agg_with_accessor: &AggregationsWithAccessor, + ) -> IntermediateRangeBucketEntry { + let sub_aggregation = if let Some(sub_aggregation) = self.sub_aggregation { + sub_aggregation.into_intermediate_aggregations_result(agg_with_accessor) + } else { + Default::default() + }; + + IntermediateRangeBucketEntry { + key: self.key, + doc_count: self.doc_count, + sub_aggregation, + from: self.from, + to: self.to, + } + } +} + impl SegmentRangeCollector { - pub fn into_intermediate_bucket_result(self) -> IntermediateBucketResult { + pub fn into_intermediate_bucket_result( + self, + agg_with_accessor: &BucketAggregationWithAccessor, + ) -> IntermediateBucketResult { let field_type = self.field_type; let buckets = self @@ -112,7 +158,9 @@ impl SegmentRangeCollector { .map(move |range_bucket| { ( range_to_string(&range_bucket.range, &field_type), - range_bucket.bucket.into(), + range_bucket + .bucket + .into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation), ) }) .collect(); @@ -175,11 +223,12 @@ impl SegmentRangeCollector { force_flush: bool, ) { let mut iter = doc.chunks_exact(4); + let accessor = bucket_with_accessor.accessor.as_single(); for docs in iter.by_ref() { - let val1 = bucket_with_accessor.accessor.get(docs[0]); - let val2 = bucket_with_accessor.accessor.get(docs[1]); - let val3 = bucket_with_accessor.accessor.get(docs[2]); - let val4 = bucket_with_accessor.accessor.get(docs[3]); + let val1 = accessor.get(docs[0]); + let val2 = accessor.get(docs[1]); + let val3 = accessor.get(docs[2]); + let val4 = accessor.get(docs[3]); let bucket_pos1 = self.get_bucket_pos(val1); let bucket_pos2 = self.get_bucket_pos(val2); let bucket_pos3 = self.get_bucket_pos(val3); @@ -191,7 +240,7 @@ impl SegmentRangeCollector { self.increment_bucket(bucket_pos4, docs[3], &bucket_with_accessor.sub_aggregation); } for doc in iter.remainder() { - let val = bucket_with_accessor.accessor.get(*doc); + let val = accessor.get(*doc); let bucket_pos = self.get_bucket_pos(val); self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation); } @@ -487,11 +536,7 @@ mod tests { #[test] fn range_binary_search_test_f64() { - let ranges = vec![ - //(f64::MIN..10.0).into(), - (10.0..100.0).into(), - //(100.0..f64::MAX).into(), - ]; + let ranges = vec![(10.0..100.0).into()]; let collector = get_collector_from_ranges(ranges, Type::F64); let search = |val: u64| collector.get_bucket_pos(val); diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs new file mode 100644 index 000000000..5ac015cd9 --- /dev/null +++ b/src/aggregation/bucket/term_agg.rs @@ -0,0 +1,602 @@ +use std::fmt::Debug; + +use fnv::FnvHashMap; +use serde::{Deserialize, Serialize}; + +use crate::aggregation::agg_req_with_accessor::{ + AggregationsWithAccessor, BucketAggregationWithAccessor, +}; +use crate::aggregation::intermediate_agg_result::{ + IntermediateBucketResult, IntermediateTermBucketEntry, IntermediateTermBucketResult, +}; +use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector; +use crate::fastfield::MultiValuedFastFieldReader; +use crate::schema::Type; +use crate::DocId; + 
+/// Creates a bucket for every unique term.
+///
+/// ### Terminology
+/// Shard and segment are used interchangeably here.
+///
+/// ## Document count error
+/// To improve performance, results from one segment are cut off at `shard_size`. On a single
+/// segment this is exact. When combining the results of multiple segments, each term that does
+/// not make it into the top `shard_size` of a segment increases the theoretical upper bound of
+/// the error by that segment's lowest term count.
+///
+/// Even with a larger `shard_size` value, `doc_count` values for a terms aggregation may be
+/// approximate. As a result, any sub-aggregations on the terms aggregation may also be approximate.
+/// `sum_other_doc_count` is the number of documents that didn't make it into the top `size` terms.
+/// If it is greater than 0, you can be sure that the terms aggregation had to throw away some
+/// buckets, either because they didn't fit into `size` on the root node or because they didn't
+/// fit into `shard_size` on the leaf node.
+///
+/// ## Per bucket document count error
+/// If you set the `show_term_doc_count_error` parameter to true, the terms aggregation will include
+/// `doc_count_error_upper_bound`, which is an upper bound to the error on the `doc_count` returned
+/// by each shard. It is the sum of the sizes of the largest buckets on each segment that did not
+/// fit into `shard_size`.
+///
+/// Result type is [BucketResult](crate::aggregation::agg_result::BucketResult) with
+/// [BucketEntry](crate::aggregation::agg_result::BucketEntry) on the
+/// AggregationCollector.
+///
+/// Result type is
+/// [crate::aggregation::intermediate_agg_result::IntermediateBucketResult] with
+/// [crate::aggregation::intermediate_agg_result::IntermediateTermBucketEntry] on the
+/// DistributedAggregationCollector.
+///
+/// # Limitations/Compatibility
+///
+/// # Request JSON Format
+/// ```json
+/// {
+///     "genres": {
+///         "terms": { "field": "genre" }
+///     }
+/// }
+/// ```
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub struct TermsAggregation {
+    /// The field to aggregate on.
+    pub field: String,
+    /// By default, the top 10 terms with the most documents are returned.
+    /// Larger values for `size` are more expensive.
+    pub size: Option<u32>,
+
+    /// To get more accurate results, we fetch more than `size` terms from each segment.
+    /// By default we fetch `shard_size` terms, which defaults to `size * 1.5 + 10`.
+    pub shard_size: Option<u32>,
+
+    /// If you set the `show_term_doc_count_error` parameter to true, the terms aggregation will
+    /// include `doc_count_error_upper_bound`, which is an upper bound to the error on the
+    /// `doc_count` returned by each shard. It is the sum of the sizes of the largest buckets on
+    /// each segment that did not fit into `shard_size`.
+    #[serde(default = "default_show_term_doc_count_error")]
+    pub show_term_doc_count_error: bool,
+
+    /// Filter out all terms with a document count lower than `min_doc_count`.
+    pub min_doc_count: Option<u64>,
+}
+impl Default for TermsAggregation {
+    fn default() -> Self {
+        Self {
+            field: Default::default(),
+            size: Default::default(),
+            shard_size: Default::default(),
+            show_term_doc_count_error: true,
+            min_doc_count: Default::default(),
+        }
+    }
+}
+
+fn default_show_term_doc_count_error() -> bool {
+    true
+}
+
+/// Same as TermsAggregation, but with populated defaults.
+#[derive(Clone, Debug, PartialEq)]
+pub(crate) struct TermsAggregationInternal {
+    /// The field to aggregate on.
+    pub field: String,
+    /// By default, the top 10 terms with the most documents are returned.
+    /// Larger values for `size` are more expensive.
+ pub size: u32, + + /// The get more accurate results, we fetch more than `size` from each segment. + /// By default we fetch `shard_size` terms, which defaults to size * 1.5 + 10. + /// + /// Cannot be smaller than size. In that case it will be set automatically to size. + pub shard_size: u32, + + /// If you set the `show_term_doc_count_error` parameter to true, the terms aggregation will + /// include doc_count_error_upper_bound, which is an upper bound to the error on the + /// doc_count returned by each shard. It’s the sum of the size of the largest bucket on + /// each segment that didn’t fit into `shard_size`. + pub show_term_doc_count_error: bool, + + /// Filter all terms than are lower `min_doc_count`. + pub min_doc_count: Option, +} + +impl TermsAggregationInternal { + pub(crate) fn from_req(req: &TermsAggregation) -> Self { + let size = req.size.unwrap_or(10); + + let mut shard_size = req + .shard_size + .unwrap_or((size as f32 * 1.5_f32) as u32 + 10); + + shard_size = shard_size.max(size); + TermsAggregationInternal { + field: req.field.to_string(), + size, + shard_size, + show_term_doc_count_error: req.show_term_doc_count_error, + min_doc_count: req.min_doc_count, + } + } +} + +const TERM_BUCKET_SIZE: usize = 100; +#[derive(Clone, Debug, PartialEq)] +/// Chunks the term_id value range in TERM_BUCKET_SIZE blocks. +struct TermBuckets { + pub(crate) entries: FnvHashMap, + blueprint: Option, +} + +#[derive(Clone, PartialEq, Default)] +struct TermBucketEntry { + doc_count: u64, + sub_aggregations: Option, +} + +impl Debug for TermBucketEntry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TermBucketEntry") + .field("doc_count", &self.doc_count) + .finish() + } +} + +impl TermBucketEntry { + fn from_blueprint(blueprint: &Option) -> Self { + Self { + doc_count: 0, + sub_aggregations: blueprint.clone(), + } + } + + pub(crate) fn into_intermediate_bucket_entry( + self, + agg_with_accessor: &AggregationsWithAccessor, + ) -> IntermediateTermBucketEntry { + let sub_aggregation = if let Some(sub_aggregation) = self.sub_aggregations { + sub_aggregation.into_intermediate_aggregations_result(agg_with_accessor) + } else { + Default::default() + }; + + IntermediateTermBucketEntry { + doc_count: self.doc_count, + sub_aggregation, + } + } +} + +impl TermBuckets { + pub(crate) fn from_req_and_validate( + sub_aggregation: &AggregationsWithAccessor, + max_term_id: usize, + ) -> crate::Result { + let has_sub_aggregations = sub_aggregation.is_empty(); + let _num_chunks = (max_term_id / TERM_BUCKET_SIZE) + 1; + + let blueprint = if has_sub_aggregations { + let sub_aggregation = + SegmentAggregationResultsCollector::from_req_and_validate(sub_aggregation)?; + Some(sub_aggregation) + } else { + None + }; + + Ok(TermBuckets { + blueprint, + entries: Default::default(), + }) + } + + fn increment_bucket( + &mut self, + term_ids: &[u64], + doc: DocId, + bucket_with_accessor: &AggregationsWithAccessor, + blueprint: &Option, + ) { + // self.ensure_vec_exists(term_ids); + for &term_id in term_ids { + let entry = self + .entries + .entry(term_id as u32) + .or_insert_with(|| TermBucketEntry::from_blueprint(blueprint)); + entry.doc_count += 1; + if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() { + sub_aggregations.collect(doc, bucket_with_accessor); + } + } + } + + fn force_flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) { + for entry in &mut self.entries.values_mut() { + if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() { + 
sub_aggregations.flush_staged_docs(agg_with_accessor, false); + } + } + } +} + +/// The collector puts values from the fast field into the correct buckets and does a conversion to +/// the correct datatype. +#[derive(Clone, Debug, PartialEq)] +pub struct SegmentTermCollector { + /// The buckets containing the aggregation data. + term_buckets: TermBuckets, + req: TermsAggregationInternal, + field_type: Type, + blueprint: Option, +} + +impl SegmentTermCollector { + pub(crate) fn from_req_and_validate( + req: &TermsAggregation, + sub_aggregations: &AggregationsWithAccessor, + field_type: Type, + accessor: &MultiValuedFastFieldReader, + ) -> crate::Result { + let max_term_id = accessor.max_value(); + let term_buckets = + TermBuckets::from_req_and_validate(sub_aggregations, max_term_id as usize)?; + + let has_sub_aggregations = sub_aggregations.is_empty(); + let blueprint = if has_sub_aggregations { + let sub_aggregation = + SegmentAggregationResultsCollector::from_req_and_validate(sub_aggregations)?; + Some(sub_aggregation) + } else { + None + }; + + Ok(SegmentTermCollector { + req: TermsAggregationInternal::from_req(req), + term_buckets, + field_type, + blueprint, + }) + } + + pub(crate) fn into_intermediate_bucket_result( + self, + agg_with_accessor: &BucketAggregationWithAccessor, + ) -> IntermediateBucketResult { + let mut entries: Vec<_> = self.term_buckets.entries.into_iter().collect(); + + let (term_doc_count_before_cutoff, sum_other_doc_count) = + cut_off_buckets(&mut entries, self.req.shard_size as usize); + + let inverted_index = agg_with_accessor + .inverted_index + .as_ref() + .expect("internal error: inverted index not loaded for term aggregation"); + let term_dict = inverted_index.terms(); + + let mut dict: FnvHashMap = Default::default(); + let mut buffer = vec![]; + for (term_id, entry) in entries { + term_dict + .ord_to_term(term_id as u64, &mut buffer) + .expect("could not find term"); + dict.insert( + String::from_utf8(buffer.to_vec()).unwrap(), + entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation), + ); + } + IntermediateBucketResult::Terms(IntermediateTermBucketResult { + entries: dict, + sum_other_doc_count, + doc_count_error_upper_bound: term_doc_count_before_cutoff, + }) + } + + #[inline] + pub(crate) fn collect_block( + &mut self, + doc: &[DocId], + bucket_with_accessor: &BucketAggregationWithAccessor, + force_flush: bool, + ) { + let accessor = bucket_with_accessor.accessor.as_multi(); + let mut iter = doc.chunks_exact(4); + let mut vals1 = vec![]; + let mut vals2 = vec![]; + let mut vals3 = vec![]; + let mut vals4 = vec![]; + for docs in iter.by_ref() { + accessor.get_vals(docs[0], &mut vals1); + accessor.get_vals(docs[1], &mut vals2); + accessor.get_vals(docs[2], &mut vals3); + accessor.get_vals(docs[3], &mut vals4); + + self.term_buckets.increment_bucket( + &vals1, + docs[0], + &bucket_with_accessor.sub_aggregation, + &self.blueprint, + ); + self.term_buckets.increment_bucket( + &vals2, + docs[1], + &bucket_with_accessor.sub_aggregation, + &self.blueprint, + ); + self.term_buckets.increment_bucket( + &vals3, + docs[2], + &bucket_with_accessor.sub_aggregation, + &self.blueprint, + ); + self.term_buckets.increment_bucket( + &vals4, + docs[3], + &bucket_with_accessor.sub_aggregation, + &self.blueprint, + ); + } + for &doc in iter.remainder() { + accessor.get_vals(doc, &mut vals1); + + self.term_buckets.increment_bucket( + &vals1, + doc, + &bucket_with_accessor.sub_aggregation, + &self.blueprint, + ); + } + if force_flush { + self.term_buckets + 
.force_flush(&bucket_with_accessor.sub_aggregation); + } + } +} + +pub(crate) trait GetDocCount { + fn doc_count(&self) -> u64; +} +impl GetDocCount for (u32, TermBucketEntry) { + fn doc_count(&self) -> u64 { + self.1.doc_count + } +} + +pub(crate) fn cut_off_buckets( + entries: &mut Vec, + num_elem: usize, +) -> (u64, u64) { + entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.doc_count())); + + let term_doc_count_before_cutoff = entries + .get(num_elem) + .map(|entry| entry.doc_count()) + .unwrap_or(0); + + let sum_other_doc_count = entries + .get(num_elem..) + .map(|cut_off_range| cut_off_range.iter().map(|entry| entry.doc_count()).sum()) + .unwrap_or(0); + + entries.truncate(num_elem); + (term_doc_count_before_cutoff, sum_other_doc_count) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::aggregation::agg_req::{ + Aggregation, Aggregations, BucketAggregation, BucketAggregationType, + }; + use crate::aggregation::tests::{exec_request, get_test_index_from_terms}; + + #[test] + fn terms_aggregation_test_single_segment() -> crate::Result<()> { + terms_aggregation_test_merge_segment(true) + } + #[test] + fn terms_aggregation_test() -> crate::Result<()> { + terms_aggregation_test_merge_segment(false) + } + fn terms_aggregation_test_merge_segment(merge_segments: bool) -> crate::Result<()> { + let segment_and_terms = vec![ + vec!["terma"], + vec!["termb"], + vec!["termc"], + vec!["terma"], + vec!["terma"], + vec!["terma"], + vec!["termb"], + vec!["terma"], + ]; + let index = get_test_index_from_terms(merge_segments, &segment_and_terms)?; + + let agg_req: Aggregations = vec![( + "my_texts".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Terms(TermsAggregation { + field: "string_id".to_string(), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let res = exec_request(agg_req, &index)?; + assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma"); + assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5); + assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb"); + assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2); + assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 1); + assert_eq!(res["my_texts"]["buckets"][2]["key"], "termc"); + assert_eq!(res["my_texts"]["sum_other_doc_count"], 0); + + let agg_req: Aggregations = vec![( + "my_texts".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Terms(TermsAggregation { + field: "string_id".to_string(), + size: Some(2), + shard_size: Some(2), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let res = exec_request(agg_req, &index)?; + assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma"); + assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5); + assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb"); + assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2); + assert_eq!( + res["my_texts"]["buckets"][2]["key"], + serde_json::Value::Null + ); + assert_eq!(res["my_texts"]["sum_other_doc_count"], 1); + + Ok(()) + } + + #[test] + fn terms_aggregation_error_count_test() -> crate::Result<()> { + let terms_per_segment = vec![ + vec!["terma", "terma", "termb", "termb", "termb", "termc"], /* termc doesn't make it + * from this segment */ + vec!["terma", "terma", "termb", "termc", "termc"], /* termb doesn't make it from + * this segment */ + ]; + + let index = get_test_index_from_terms(false, 
&terms_per_segment)?; + + let agg_req: Aggregations = vec![( + "my_texts".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Terms(TermsAggregation { + field: "string_id".to_string(), + size: Some(2), + shard_size: Some(2), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let res = exec_request(agg_req, &index)?; + println!("{}", &serde_json::to_string_pretty(&res).unwrap()); + + assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma"); + assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 4); + assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb"); + assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 3); + assert_eq!( + res["my_texts"]["buckets"][2]["doc_count"], + serde_json::Value::Null + ); + assert_eq!(res["my_texts"]["sum_other_doc_count"], 4); + assert_eq!(res["my_texts"]["doc_count_error_upper_bound"], 2); + + Ok(()) + } +} + +#[cfg(all(test, feature = "unstable"))] +mod bench { + + use fnv::FnvHashMap; + use itertools::Itertools; + use rand::seq::SliceRandom; + use rand::thread_rng; + + use super::*; + + fn get_collector_with_buckets(num_docs: u64) -> TermBuckets { + TermBuckets::from_req_and_validate(&Default::default(), num_docs as usize).unwrap() + } + + fn get_rand_terms(total_terms: u64, num_terms_returned: u64) -> Vec { + let mut rng = thread_rng(); + + let all_terms = (0..total_terms - 1).collect_vec(); + + let mut vals = vec![]; + for _ in 0..num_terms_returned { + let val = all_terms.as_slice().choose(&mut rng).unwrap(); + vals.push(*val); + } + + vals + } + + fn bench_term_hashmap(b: &mut test::Bencher, num_terms: u64, total_terms: u64) { + let mut collector = FnvHashMap::default(); + let vals = get_rand_terms(total_terms, num_terms); + b.iter(|| { + for val in &vals { + let val = collector.entry(val).or_insert(TermBucketEntry::default()); + val.doc_count += 1; + } + collector.get(&0).cloned() + }) + } + fn bench_term_buckets(b: &mut test::Bencher, num_terms: u64, total_terms: u64) { + let mut collector = get_collector_with_buckets(total_terms); + let vals = get_rand_terms(total_terms, num_terms); + let aggregations_with_accessor: AggregationsWithAccessor = Default::default(); + b.iter(|| { + for &val in &vals { + collector.increment_bucket(&[val], 0, &aggregations_with_accessor, &None); + } + }) + } + + #[bench] + fn bench_term_buckets_500_of_1_000_000(b: &mut test::Bencher) { + bench_term_buckets(b, 500u64, 1_000_000u64) + } + #[bench] + fn bench_fnv_buckets_500_of_1_000_000(b: &mut test::Bencher) { + bench_term_hashmap(b, 500u64, 1_000_000u64) + } + + #[bench] + fn bench_term_buckets_1_000_000_of_50_000(b: &mut test::Bencher) { + bench_term_buckets(b, 1_000_000u64, 50_000u64) + } + #[bench] + fn bench_fnv_buckets_1_000_000_of_50_000(b: &mut test::Bencher) { + bench_term_hashmap(b, 1_000_000u64, 50_000u64) + } + + #[bench] + fn bench_term_buckets_1_000_000_of_50(b: &mut test::Bencher) { + bench_term_buckets(b, 1_000_000u64, 50u64) + } + #[bench] + fn bench_fnv_buckets_1_000_000_of_50(b: &mut test::Bencher) { + bench_term_hashmap(b, 1_000_000u64, 50u64) + } +} diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 227836151..2191f8b7c 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -106,7 +106,7 @@ fn merge_fruits( /// AggregationSegmentCollector does the aggregation collection on a segment. 
pub struct AggregationSegmentCollector { - aggs: AggregationsWithAccessor, + aggs_with_accessor: AggregationsWithAccessor, result: SegmentAggregationResultsCollector, } @@ -121,7 +121,7 @@ impl AggregationSegmentCollector { let result = SegmentAggregationResultsCollector::from_req_and_validate(&aggs_with_accessor)?; Ok(AggregationSegmentCollector { - aggs: aggs_with_accessor, + aggs_with_accessor, result, }) } @@ -132,11 +132,13 @@ impl SegmentCollector for AggregationSegmentCollector { #[inline] fn collect(&mut self, doc: crate::DocId, _score: crate::Score) { - self.result.collect(doc, &self.aggs); + self.result.collect(doc, &self.aggs_with_accessor); } fn harvest(mut self) -> Self::Fruit { - self.result.flush_staged_docs(&self.aggs, true); - self.result.into() + self.result + .flush_staged_docs(&self.aggs_with_accessor, true); + self.result + .into_intermediate_aggregations_result(&self.aggs_with_accessor) } } diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 6c577944b..8cd137ba5 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -9,12 +9,13 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; use super::agg_req::{AggregationsInternal, BucketAggregationType, MetricAggregation}; +use super::agg_result::BucketResult; +use super::bucket::{cut_off_buckets, SegmentHistogramBucketEntry, TermsAggregation}; use super::metric::{IntermediateAverage, IntermediateStats}; -use super::segment_agg_result::{ - SegmentAggregationResultsCollector, SegmentBucketResultCollector, SegmentHistogramBucketEntry, - SegmentMetricResultCollector, SegmentRangeBucketEntry, -}; +use super::segment_agg_result::SegmentMetricResultCollector; use super::{Key, SerializedKey, VecWithNames}; +use crate::aggregation::agg_result::{AggregationResults, BucketEntry}; +use crate::aggregation::bucket::TermsAggregationInternal; /// Contains the intermediate aggregation result, which is optimized to be merged with other /// intermediate results. 
@@ -24,15 +25,6 @@ pub struct IntermediateAggregationResults { pub(crate) buckets: Option>, } -impl From for IntermediateAggregationResults { - fn from(tree: SegmentAggregationResultsCollector) -> Self { - let metrics = tree.metrics.map(VecWithNames::from_other); - let buckets = tree.buckets.map(VecWithNames::from_other); - - Self { metrics, buckets } - } -} - impl IntermediateAggregationResults { pub(crate) fn empty_from_req(req: &AggregationsInternal) -> Self { let metrics = if req.metrics.is_empty() { @@ -169,22 +161,14 @@ pub enum IntermediateBucketResult { /// The buckets buckets: Vec, }, -} - -impl From for IntermediateBucketResult { - fn from(collector: SegmentBucketResultCollector) -> Self { - match collector { - SegmentBucketResultCollector::Range(range) => range.into_intermediate_bucket_result(), - SegmentBucketResultCollector::Histogram(histogram) => { - histogram.into_intermediate_bucket_result() - } - } - } + /// Term aggregation + Terms(IntermediateTermBucketResult), } impl IntermediateBucketResult { pub(crate) fn empty_from_req(req: &BucketAggregationType) -> Self { match req { + BucketAggregationType::Terms(_) => IntermediateBucketResult::Terms(Default::default()), BucketAggregationType::Range(_) => IntermediateBucketResult::Range(Default::default()), BucketAggregationType::Histogram(_) => { IntermediateBucketResult::Histogram { buckets: vec![] } @@ -193,6 +177,16 @@ impl IntermediateBucketResult { } fn merge_fruits(&mut self, other: IntermediateBucketResult) { match (self, other) { + ( + IntermediateBucketResult::Terms(entries_left), + IntermediateBucketResult::Terms(entries_right), + ) => { + merge_maps(&mut entries_left.entries, entries_right.entries); + entries_left.sum_other_doc_count += entries_right.sum_other_doc_count; + entries_left.doc_count_error_upper_bound += + entries_right.doc_count_error_upper_bound; + } + ( IntermediateBucketResult::Range(entries_left), IntermediateBucketResult::Range(entries_right), @@ -232,6 +226,59 @@ impl IntermediateBucketResult { (IntermediateBucketResult::Histogram { .. }, _) => { panic!("try merge on different types") } + (IntermediateBucketResult::Terms { .. }, _) => { + panic!("try merge on different types") + } + } + } +} + +#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)] +/// Term aggregation including error counts +pub struct IntermediateTermBucketResult { + pub(crate) entries: FnvHashMap, + pub(crate) sum_other_doc_count: u64, + pub(crate) doc_count_error_upper_bound: u64, +} + +impl IntermediateTermBucketResult { + pub(crate) fn into_final_result( + self, + req: &TermsAggregation, + sub_aggregation_req: &AggregationsInternal, + ) -> BucketResult { + let req = TermsAggregationInternal::from_req(req); + let mut buckets: Vec = self + .entries + .into_iter() + .map(|(key, entry)| BucketEntry { + key: Key::Str(key), + doc_count: entry.doc_count, + sub_aggregation: AggregationResults::from_intermediate_and_req_internal( + entry.sub_aggregation, + sub_aggregation_req, + ), + }) + .collect(); + buckets.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.doc_count)); + // We ignore _term_doc_count_before_cutoff here, because it increases the upperbound error + // only for terms that didn't make it into the top N. + // + // This can be interesting, as a value of quality of the results, but not good to check the + // actual error count for the returned terms. 
+ let (_term_doc_count_before_cutoff, sum_other_doc_count) = + cut_off_buckets(&mut buckets, req.size as usize); + + let doc_count_error_upper_bound = if req.show_term_doc_count_error { + Some(self.doc_count_error_upper_bound) + } else { + None + }; + + BucketResult::Terms { + buckets, + sum_other_doc_count: self.sum_other_doc_count + sum_other_doc_count, + doc_count_error_upper_bound, } } } @@ -277,26 +324,6 @@ impl From for IntermediateHistogramBucketEntry { } } -impl - From<( - SegmentHistogramBucketEntry, - SegmentAggregationResultsCollector, - )> for IntermediateHistogramBucketEntry -{ - fn from( - entry: ( - SegmentHistogramBucketEntry, - SegmentAggregationResultsCollector, - ), - ) -> Self { - IntermediateHistogramBucketEntry { - key: entry.0.key, - doc_count: entry.0.doc_count, - sub_aggregation: entry.1.into(), - } - } -} - /// This is the range entry for a bucket, which contains a key, count, and optionally /// sub_aggregations. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -305,7 +332,6 @@ pub struct IntermediateRangeBucketEntry { pub key: Key, /// The number of documents in the bucket. pub doc_count: u64, - pub(crate) values: Option>, /// The sub_aggregation in this bucket. pub sub_aggregation: IntermediateAggregationResults, /// The from range of the bucket. Equals f64::MIN when None. @@ -316,22 +342,20 @@ pub struct IntermediateRangeBucketEntry { pub to: Option, } -impl From for IntermediateRangeBucketEntry { - fn from(entry: SegmentRangeBucketEntry) -> Self { - let sub_aggregation = if let Some(sub_aggregation) = entry.sub_aggregation { - sub_aggregation.into() - } else { - Default::default() - }; +/// This is the term entry for a bucket, which contains a count, and optionally +/// sub_aggregations. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct IntermediateTermBucketEntry { + /// The number of documents in the bucket. + pub doc_count: u64, + /// The sub_aggregation in this bucket. 
+ pub sub_aggregation: IntermediateAggregationResults, +} - IntermediateRangeBucketEntry { - key: entry.key, - doc_count: entry.doc_count, - values: None, - sub_aggregation, - to: entry.to, - from: entry.from, - } +impl MergeFruits for IntermediateTermBucketEntry { + fn merge_fruits(&mut self, other: IntermediateTermBucketEntry) { + self.doc_count += other.doc_count; + self.sub_aggregation.merge_fruits(other.sub_aggregation); } } @@ -366,7 +390,6 @@ mod tests { IntermediateRangeBucketEntry { key: Key::Str(key.to_string()), doc_count: *doc_count, - values: None, sub_aggregation: Default::default(), from: None, to: None, @@ -394,7 +417,6 @@ mod tests { IntermediateRangeBucketEntry { key: Key::Str(key.to_string()), doc_count: *doc_count, - values: None, from: None, to: None, sub_aggregation: get_sub_test_tree(&[( diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 24e031c8c..24c8caf93 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -318,7 +318,7 @@ mod tests { use crate::aggregation::segment_agg_result::DOC_BLOCK_SIZE; use crate::aggregation::DistributedAggregationCollector; use crate::query::{AllQuery, TermQuery}; - use crate::schema::{Cardinality, IndexRecordOption, Schema, TextFieldIndexing}; + use crate::schema::{Cardinality, IndexRecordOption, Schema, TextFieldIndexing, FAST, STRING}; use crate::{Index, Term}; fn get_avg_req(field_name: &str) -> Aggregation { @@ -337,17 +337,79 @@ mod tests { ) } + pub fn exec_request(agg_req: Aggregations, index: &Index) -> crate::Result { + exec_request_with_query(agg_req, index, None) + } + pub fn exec_request_with_query( + agg_req: Aggregations, + index: &Index, + query: Option<(&str, &str)>, + ) -> crate::Result { + let collector = AggregationCollector::from_aggs(agg_req); + + let reader = index.reader()?; + let searcher = reader.searcher(); + let agg_res = if let Some((field, term)) = query { + let text_field = reader.searcher().schema().get_field(field).unwrap(); + + let term_query = TermQuery::new( + Term::from_field_text(text_field, term), + IndexRecordOption::Basic, + ); + + searcher.search(&term_query, &collector)? + } else { + searcher.search(&AllQuery, &collector)? 
+ }; + + let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?; + Ok(res) + } + pub fn get_test_index_from_values( merge_segments: bool, values: &[f64], + ) -> crate::Result { + // Every value gets its own segment + let mut segment_and_values = vec![]; + for value in values { + segment_and_values.push(vec![(*value, value.to_string())]); + } + get_test_index_from_values_and_terms(merge_segments, &segment_and_values) + } + + pub fn get_test_index_from_terms( + merge_segments: bool, + values: &[Vec<&str>], + ) -> crate::Result { + // Every value gets its own segment + let segment_and_values = values + .iter() + .map(|terms| { + terms + .iter() + .enumerate() + .map(|(i, term)| (i as f64, term.to_string())) + .collect() + }) + .collect::>(); + get_test_index_from_values_and_terms(merge_segments, &segment_and_values) + } + + pub fn get_test_index_from_values_and_terms( + merge_segments: bool, + segment_and_values: &[Vec<(f64, String)>], ) -> crate::Result { let mut schema_builder = Schema::builder(); let text_fieldtype = crate::schema::TextOptions::default() .set_indexing_options( TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs), ) + .set_fast() .set_stored(); - let text_field = schema_builder.add_text_field("text", text_fieldtype); + let text_field = schema_builder.add_text_field("text", text_fieldtype.clone()); + let text_field_id = schema_builder.add_text_field("text_id", text_fieldtype); + let string_field_id = schema_builder.add_text_field("string_id", STRING | FAST); let score_fieldtype = crate::schema::NumericOptions::default().set_fast(Cardinality::SingleValue); let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone()); @@ -360,15 +422,20 @@ mod tests { let index = Index::create_in_ram(schema_builder.build()); { let mut index_writer = index.writer_for_tests()?; - for &i in values { - // writing the segment - index_writer.add_document(doc!( - text_field => "cool", - score_field => i as u64, - score_field_f64 => i as f64, - score_field_i64 => i as i64, - fraction_field => i as f64/100.0, - ))?; + for values in segment_and_values { + for (i, term) in values { + let i = *i; + // writing the segment + index_writer.add_document(doc!( + text_field => "cool", + text_field_id => term.to_string(), + string_field_id => term.to_string(), + score_field => i as u64, + score_field_f64 => i as f64, + score_field_i64 => i as i64, + fraction_field => i as f64/100.0, + ))?; + } index_writer.commit()?; } } @@ -968,7 +1035,7 @@ mod tests { let agg_res = avg_on_field("text"); assert_eq!( format!("{:?}", agg_res), - r#"InvalidArgument("Only single value fast fields of type f64, u64, i64 are supported, but got Str ")"# + r#"InvalidArgument("Only fast fields of type f64, u64, i64 are supported, but got Str ")"# ); let agg_res = avg_on_field("not_exist_field"); @@ -989,11 +1056,12 @@ mod tests { #[cfg(all(test, feature = "unstable"))] mod bench { + use rand::prelude::SliceRandom; use rand::{thread_rng, Rng}; use test::{self, Bencher}; use super::*; - use crate::aggregation::bucket::{HistogramAggregation, HistogramBounds}; + use crate::aggregation::bucket::{HistogramAggregation, HistogramBounds, TermsAggregation}; use crate::aggregation::metric::StatsAggregation; use crate::query::AllQuery; @@ -1005,6 +1073,10 @@ mod tests { ) .set_stored(); let text_field = schema_builder.add_text_field("text", text_fieldtype); + let text_field_many_terms = + schema_builder.add_text_field("text_many_terms", STRING | FAST); + let text_field_few_terms = + 
schema_builder.add_text_field("text_few_terms", STRING | FAST); let score_fieldtype = crate::schema::NumericOptions::default().set_fast(Cardinality::SingleValue); let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone()); @@ -1012,6 +1084,7 @@ mod tests { schema_builder.add_f64_field("score_f64", score_fieldtype.clone()); let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype); let index = Index::create_from_tempdir(schema_builder.build())?; + let few_terms_data = vec!["INFO", "ERROR", "WARN", "DEBUG"]; { let mut rng = thread_rng(); let mut index_writer = index.writer_for_tests()?; @@ -1020,6 +1093,8 @@ mod tests { let val: f64 = rng.gen_range(0.0..1_000_000.0); index_writer.add_document(doc!( text_field => "cool", + text_field_many_terms => val.to_string(), + text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(), score_field => val as u64, score_field_f64 => val as f64, score_field_i64 => val as i64, @@ -1171,6 +1246,64 @@ mod tests { }); } + #[bench] + fn bench_aggregation_terms_few(b: &mut Bencher) { + let index = get_test_index_bench(false).unwrap(); + let reader = index.reader().unwrap(); + + b.iter(|| { + let agg_req: Aggregations = vec![( + "my_texts".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Terms(TermsAggregation { + field: "text_few_terms".to_string(), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let collector = AggregationCollector::from_aggs(agg_req); + + let searcher = reader.searcher(); + let agg_res: AggregationResults = + searcher.search(&AllQuery, &collector).unwrap().into(); + + agg_res + }); + } + + #[bench] + fn bench_aggregation_terms_many(b: &mut Bencher) { + let index = get_test_index_bench(false).unwrap(); + let reader = index.reader().unwrap(); + + b.iter(|| { + let agg_req: Aggregations = vec![( + "my_texts".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Terms(TermsAggregation { + field: "text_many_terms".to_string(), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let collector = AggregationCollector::from_aggs(agg_req); + + let searcher = reader.searcher(); + let agg_res: AggregationResults = + searcher.search(&AllQuery, &collector).unwrap().into(); + + agg_res + }); + } + #[bench] fn bench_aggregation_range_only(b: &mut Bencher) { let index = get_test_index_bench(false).unwrap(); diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 0064546a3..b0645785f 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -9,11 +9,12 @@ use super::agg_req::MetricAggregation; use super::agg_req_with_accessor::{ AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor, }; -use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector}; +use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector}; +use super::intermediate_agg_result::{IntermediateAggregationResults, IntermediateBucketResult}; use super::metric::{ AverageAggregation, SegmentAverageCollector, SegmentStatsCollector, StatsAggregation, }; -use super::{Key, VecWithNames}; +use super::VecWithNames; use crate::aggregation::agg_req::BucketAggregationType; use crate::DocId; @@ -40,6 +41,23 @@ impl Debug for SegmentAggregationResultsCollector { } impl 
SegmentAggregationResultsCollector { + pub fn into_intermediate_aggregations_result( + self, + agg_with_accessor: &AggregationsWithAccessor, + ) -> IntermediateAggregationResults { + let buckets = self.buckets.map(|buckets| { + let entries = buckets + .into_iter() + .zip(agg_with_accessor.buckets.values()) + .map(|((key, bucket), acc)| (key, bucket.into_intermediate_bucket_result(acc))) + .collect::>(); + VecWithNames::from_entries(entries) + }); + let metrics = self.metrics.map(VecWithNames::from_other); + + IntermediateAggregationResults { metrics, buckets } + } + pub(crate) fn from_req_and_validate(req: &AggregationsWithAccessor) -> crate::Result { let buckets = req .buckets @@ -97,6 +115,9 @@ impl SegmentAggregationResultsCollector { agg_with_accessor: &AggregationsWithAccessor, force_flush: bool, ) { + if self.num_staged_docs == 0 { + return; + } if let Some(metrics) = &mut self.metrics { for (collector, agg_with_accessor) in metrics.values_mut().zip(agg_with_accessor.metrics.values()) @@ -162,12 +183,38 @@ impl SegmentMetricResultCollector { #[derive(Clone, Debug, PartialEq)] pub(crate) enum SegmentBucketResultCollector { Range(SegmentRangeCollector), - Histogram(SegmentHistogramCollector), + Histogram(Box), + Terms(Box), } impl SegmentBucketResultCollector { + pub fn into_intermediate_bucket_result( + self, + agg_with_accessor: &BucketAggregationWithAccessor, + ) -> IntermediateBucketResult { + match self { + SegmentBucketResultCollector::Terms(terms) => { + terms.into_intermediate_bucket_result(agg_with_accessor) + } + SegmentBucketResultCollector::Range(range) => { + range.into_intermediate_bucket_result(agg_with_accessor) + } + SegmentBucketResultCollector::Histogram(histogram) => { + histogram.into_intermediate_bucket_result(agg_with_accessor) + } + } + } + pub fn from_req_and_validate(req: &BucketAggregationWithAccessor) -> crate::Result { match &req.bucket_agg { + BucketAggregationType::Terms(terms_req) => Ok(Self::Terms(Box::new( + SegmentTermCollector::from_req_and_validate( + terms_req, + &req.sub_aggregation, + req.field_type, + req.accessor.as_multi(), + )?, + ))), BucketAggregationType::Range(range_req) => { Ok(Self::Range(SegmentRangeCollector::from_req_and_validate( range_req, @@ -175,14 +222,14 @@ impl SegmentBucketResultCollector { req.field_type, )?)) } - BucketAggregationType::Histogram(histogram) => Ok(Self::Histogram( + BucketAggregationType::Histogram(histogram) => Ok(Self::Histogram(Box::new( SegmentHistogramCollector::from_req_and_validate( histogram, &req.sub_aggregation, req.field_type, - &req.accessor, + req.accessor.as_single(), )?, - )), + ))), } } @@ -200,34 +247,9 @@ impl SegmentBucketResultCollector { SegmentBucketResultCollector::Histogram(histogram) => { histogram.collect_block(doc, bucket_with_accessor, force_flush) } + SegmentBucketResultCollector::Terms(terms) => { + terms.collect_block(doc, bucket_with_accessor, force_flush) + } } } } - -#[derive(Clone, Debug, PartialEq)] -pub(crate) struct SegmentHistogramBucketEntry { - pub key: f64, - pub doc_count: u64, -} - -#[derive(Clone, PartialEq)] -pub(crate) struct SegmentRangeBucketEntry { - pub key: Key, - pub doc_count: u64, - pub sub_aggregation: Option, - /// The from range of the bucket. Equals f64::MIN when None. - pub from: Option, - /// The to range of the bucket. Equals f64::MAX when None. 
- pub to: Option, -} - -impl Debug for SegmentRangeBucketEntry { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SegmentRangeBucketEntry") - .field("key", &self.key) - .field("doc_count", &self.doc_count) - .field("from", &self.from) - .field("to", &self.to) - .finish() - } -}
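
A minimal usage sketch (not part of the patch), based on the tests added above: it builds a terms aggregation request programmatically and reads back the serialized result. The index setup, the `genre` field name (assumed to be a `STRING | FAST` text field), and the exact public module paths are assumptions; `serde_json` is only used to inspect the result as JSON.

use tantivy::aggregation::agg_req::{
    Aggregation, Aggregations, BucketAggregation, BucketAggregationType,
};
use tantivy::aggregation::bucket::TermsAggregation;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::{Index, TantivyError};

// Runs a terms aggregation over the hypothetical "genre" field and returns the
// final aggregation result serialized as JSON.
fn top_genres(index: &Index) -> Result<serde_json::Value, TantivyError> {
    let agg_req: Aggregations = vec![(
        "genres".to_string(),
        Aggregation::Bucket(BucketAggregation {
            bucket_agg: BucketAggregationType::Terms(TermsAggregation {
                field: "genre".to_string(),
                size: Some(5),
                ..Default::default()
            }),
            sub_aggregation: Default::default(),
        }),
    )]
    .into_iter()
    .collect();

    let collector = AggregationCollector::from_aggs(agg_req);
    let searcher = index.reader()?.searcher();
    let agg_res = searcher.search(&AllQuery, &collector)?;

    // The final result has the shape exercised by the tests above:
    // res["genres"]["buckets"][N]["key"] / ["doc_count"], plus
    // res["genres"]["sum_other_doc_count"] and, when requested,
    // res["genres"]["doc_count_error_upper_bound"].
    let res: serde_json::Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
    Ok(res)
}

Since `shard_size` defaults to `size * 1.5 + 10` per segment, counts stay exact on a single segment even for small `size` values; the error bounds only come into play when merging multiple segments.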