add term aggregation

Pascal Seitz
2022-03-30 22:38:26 +08:00
parent 31d3bcfff2
commit 24432bf523
12 changed files with 1093 additions and 195 deletions

View File

@@ -8,8 +8,9 @@ Unreleased
- Converting a `time::OffsetDateTime` to `Value::Date` implicitly converts the value into UTC.
If this is not desired do the time zone conversion yourself and use `time::PrimitiveDateTime`
directly instead.
- Add [histogram](https://github.com/quickwit-oss/tantivy/pull/1306) aggregation (@PSeitz).
- Add support for fastfield on text fields (@PSeitz).
- Add [histogram](https://github.com/quickwit-oss/tantivy/pull/1306) aggregation (@PSeitz)
- Add support for fastfield on text fields (@PSeitz)
- Add terms aggregation (@PSeitz)
Tantivy 0.17
================================
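For the new "Add terms aggregation" entry, a minimal usage sketch modeled on the tests added in this commit. It assumes an index with a `STRING | FAST` text field (here called "genre" as a placeholder) and uses the module paths of the released crate, which may differ slightly at this exact commit:

use tantivy::aggregation::agg_req::{
    Aggregation, Aggregations, BucketAggregation, BucketAggregationType,
};
use tantivy::aggregation::bucket::TermsAggregation;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::Index;

fn run_terms_agg(index: &Index) -> tantivy::Result<()> {
    // One bucket per unique term of the fast field "genre" (placeholder name).
    let agg_req: Aggregations = vec![(
        "genres".to_string(),
        Aggregation::Bucket(BucketAggregation {
            bucket_agg: BucketAggregationType::Terms(TermsAggregation {
                field: "genre".to_string(),
                ..Default::default()
            }),
            sub_aggregation: Default::default(),
        }),
    )]
    .into_iter()
    .collect();

    let collector = AggregationCollector::from_aggs(agg_req);
    let searcher = index.reader()?.searcher();
    let agg_res = searcher.search(&AllQuery, &collector)?;
    // Serializes to the Elasticsearch-style shape asserted in the new tests:
    // buckets, sum_other_doc_count and doc_count_error_upper_bound.
    println!("{}", serde_json::to_string_pretty(&agg_res).unwrap());
    Ok(())
}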

View File

@@ -48,8 +48,8 @@ use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
use super::bucket::HistogramAggregation;
pub use super::bucket::RangeAggregation;
use super::bucket::{HistogramAggregation, TermsAggregation};
use super::metric::{AverageAggregation, StatsAggregation};
use super::VecWithNames;
@@ -102,8 +102,14 @@ pub(crate) struct BucketAggregationInternal {
impl BucketAggregationInternal {
pub(crate) fn as_histogram(&self) -> &HistogramAggregation {
match &self.bucket_agg {
BucketAggregationType::Range(_) => panic!("unexpected aggregation"),
BucketAggregationType::Histogram(histogram) => histogram,
_ => panic!("unexpected aggregation"),
}
}
pub(crate) fn as_term(&self) -> &TermsAggregation {
match &self.bucket_agg {
BucketAggregationType::Terms(terms) => terms,
_ => panic!("unexpected aggregation"),
}
}
}
@@ -177,11 +183,15 @@ pub enum BucketAggregationType {
/// Put data into buckets of user-defined ranges.
#[serde(rename = "histogram")]
Histogram(HistogramAggregation),
/// Put data into buckets of terms.
#[serde(rename = "terms")]
Terms(TermsAggregation),
}
impl BucketAggregationType {
fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
match self {
BucketAggregationType::Terms(terms) => fast_field_names.insert(terms.field.to_string()),
BucketAggregationType::Range(range) => fast_field_names.insert(range.field.to_string()),
BucketAggregationType::Histogram(histogram) => {
fast_field_names.insert(histogram.field.to_string())
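To illustrate the new `terms` serde tag, a rough sketch of deserializing a request into the `Aggregations` map; it assumes the Elasticsearch-style JSON request format already used by the other bucket aggregations (the serde attributes on `BucketAggregation` are not part of this diff, and the aggregation name "genres" and field "genre" are placeholders):

use tantivy::aggregation::agg_req::Aggregations;

fn parse_terms_request() -> serde_json::Result<Aggregations> {
    // The externally tagged "terms" object selects BucketAggregationType::Terms.
    serde_json::from_str(r#"{ "genres": { "terms": { "field": "genre" } } }"#)
}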

View File

@@ -1,12 +1,16 @@
//! This will enhance the request tree with access to the fastfield and metadata.
use std::sync::Arc;
use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation};
use super::bucket::{HistogramAggregation, RangeAggregation};
use super::bucket::{HistogramAggregation, RangeAggregation, TermsAggregation};
use super::metric::{AverageAggregation, StatsAggregation};
use super::VecWithNames;
use crate::fastfield::{type_and_cardinality, DynamicFastFieldReader, FastType};
use crate::fastfield::{
type_and_cardinality, DynamicFastFieldReader, FastType, MultiValuedFastFieldReader,
};
use crate::schema::{Cardinality, Type};
use crate::{SegmentReader, TantivyError};
use crate::{InvertedIndexReader, SegmentReader, TantivyError};
#[derive(Clone, Default)]
pub(crate) struct AggregationsWithAccessor {
@@ -27,11 +31,32 @@ impl AggregationsWithAccessor {
}
}
#[derive(Clone)]
pub(crate) enum FastFieldAccessor {
Multi(MultiValuedFastFieldReader<u64>),
Single(DynamicFastFieldReader<u64>),
}
impl FastFieldAccessor {
pub fn as_single(&self) -> &DynamicFastFieldReader<u64> {
match self {
FastFieldAccessor::Multi(_) => panic!("unexpected ff cardinality"),
FastFieldAccessor::Single(reader) => reader,
}
}
pub fn as_multi(&self) -> &MultiValuedFastFieldReader<u64> {
match self {
FastFieldAccessor::Multi(reader) => reader,
FastFieldAccessor::Single(_) => panic!("unexpected ff cardinality"),
}
}
}
#[derive(Clone)]
pub struct BucketAggregationWithAccessor {
/// In general there can be buckets without fast field access, e.g. buckets that are created
/// based on search terms. So eventually this needs to be an Option or be moved.
pub(crate) accessor: DynamicFastFieldReader<u64>,
pub(crate) accessor: FastFieldAccessor,
pub(crate) inverted_index: Option<Arc<InvertedIndexReader>>,
pub(crate) field_type: Type,
pub(crate) bucket_agg: BucketAggregationType,
pub(crate) sub_aggregation: AggregationsWithAccessor,
@@ -43,14 +68,25 @@ impl BucketAggregationWithAccessor {
sub_aggregation: &Aggregations,
reader: &SegmentReader,
) -> crate::Result<BucketAggregationWithAccessor> {
let mut inverted_index = None;
let (accessor, field_type) = match &bucket {
BucketAggregationType::Range(RangeAggregation {
field: field_name,
ranges: _,
}) => get_ff_reader_and_validate(reader, field_name)?,
}) => get_ff_reader_and_validate(reader, field_name, false)?,
BucketAggregationType::Histogram(HistogramAggregation {
field: field_name, ..
}) => get_ff_reader_and_validate(reader, field_name)?,
}) => get_ff_reader_and_validate(reader, field_name, false)?,
BucketAggregationType::Terms(TermsAggregation {
field: field_name, ..
}) => {
let field = reader
.schema()
.get_field(field_name)
.ok_or_else(|| TantivyError::FieldNotFound(field_name.to_string()))?;
inverted_index = Some(reader.inverted_index(field)?);
get_ff_reader_and_validate(reader, field_name, true)?
}
};
let sub_aggregation = sub_aggregation.clone();
Ok(BucketAggregationWithAccessor {
@@ -58,6 +94,7 @@ impl BucketAggregationWithAccessor {
field_type,
sub_aggregation: get_aggs_with_accessor_and_validate(&sub_aggregation, reader)?,
bucket_agg: bucket.clone(),
inverted_index,
})
}
}
@@ -78,10 +115,10 @@ impl MetricAggregationWithAccessor {
match &metric {
MetricAggregation::Average(AverageAggregation { field: field_name })
| MetricAggregation::Stats(StatsAggregation { field: field_name }) => {
let (accessor, field_type) = get_ff_reader_and_validate(reader, field_name)?;
let (accessor, field_type) = get_ff_reader_and_validate(reader, field_name, false)?;
Ok(MetricAggregationWithAccessor {
accessor,
accessor: accessor.as_single().clone(),
field_type,
metric: metric.clone(),
})
@@ -121,7 +158,8 @@ pub(crate) fn get_aggs_with_accessor_and_validate(
fn get_ff_reader_and_validate(
reader: &SegmentReader,
field_name: &str,
) -> crate::Result<(DynamicFastFieldReader<u64>, Type)> {
multi: bool,
) -> crate::Result<(FastFieldAccessor, Type)> {
let field = reader
.schema()
.get_field(field_name)
@@ -129,7 +167,7 @@ fn get_ff_reader_and_validate(
let field_type = reader.schema().get_field_entry(field).field_type();
if let Some((ff_type, cardinality)) = type_and_cardinality(field_type) {
if cardinality == Cardinality::MultiValues || ff_type == FastType::Date {
if (!multi && cardinality == Cardinality::MultiValues) || ff_type == FastType::Date {
return Err(TantivyError::InvalidArgument(format!(
"Invalid field type in aggregation {:?}, only Cardinality::SingleValue supported",
field_type.value_type()
@@ -137,13 +175,19 @@ fn get_ff_reader_and_validate(
}
} else {
return Err(TantivyError::InvalidArgument(format!(
"Only single value fast fields of type f64, u64, i64 are supported, but got {:?} ",
"Only fast fields of type f64, u64, i64 are supported, but got {:?} ",
field_type.value_type()
)));
};
let ff_fields = reader.fast_fields();
ff_fields
.u64_lenient(field)
.map(|field| (field, field_type.value_type()))
if multi {
ff_fields
.u64s_lenient(field)
.map(|field| (FastFieldAccessor::Multi(field), field_type.value_type()))
} else {
ff_fields
.u64_lenient(field)
.map(|field| (FastFieldAccessor::Single(field), field_type.value_type()))
}
}
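A note on the new `multi` path: for a text fast field the u64 values are term ordinals, and the inverted index stored above is only needed to map those ordinals back to term strings when the segment result is converted (see `into_intermediate_bucket_result` in term_agg.rs below). A rough sketch of that lookup, mirroring the calls used in this diff; the helper itself is hypothetical:

use tantivy::InvertedIndexReader;

// Resolve a term ordinal taken from the multivalued fast field back to the
// original term string via the segment's term dictionary.
fn term_ord_to_string(inv_index: &InvertedIndexReader, term_ord: u64) -> String {
    let mut bytes = Vec::new();
    inv_index
        .terms()
        .ord_to_term(term_ord, &mut bytes)
        .expect("could not find term");
    String::from_utf8(bytes).expect("text terms are valid utf-8")
}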

View File

@@ -11,7 +11,7 @@ use itertools::Itertools;
use serde::{Deserialize, Serialize};
use super::agg_req::{Aggregations, AggregationsInternal, BucketAggregationInternal};
use super::bucket::intermediate_buckets_to_final_buckets;
use super::bucket::{intermediate_buckets_to_final_buckets, GetDocCount};
use super::intermediate_agg_result::{
IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry,
IntermediateMetricResult, IntermediateRangeBucketEntry,
@@ -34,8 +34,8 @@ impl AggregationResults {
/// Convert an intermediate result and its aggregation request to the final result
///
/// Internal function, CollectorAggregations is used instead of Aggregations, which is optimized
/// for internal processing
fn from_intermediate_and_req_internal(
/// for internal processing, by splitting metric and buckets into separate groups.
pub(crate) fn from_intermediate_and_req_internal(
results: IntermediateAggregationResults,
req: &AggregationsInternal,
) -> Self {
@@ -140,6 +140,18 @@ pub enum BucketResult {
/// See [HistogramAggregation](super::bucket::HistogramAggregation)
buckets: Vec<BucketEntry>,
},
/// This is the term result
Terms {
/// The buckets.
///
/// See [TermsAggregation](super::bucket::TermsAggregation)
buckets: Vec<BucketEntry>,
/// The number of documents that didn't make it into the top N due to shard_size or size
sum_other_doc_count: u64,
#[serde(skip_serializing_if = "Option::is_none")]
/// The upper bound error for the doc count of each term.
doc_count_error_upper_bound: Option<u64>,
},
}
impl BucketResult {
@@ -173,6 +185,9 @@ impl BucketResult {
BucketResult::Histogram { buckets }
}
IntermediateBucketResult::Terms(terms) => {
terms.into_final_result(req.as_term(), &req.sub_aggregation)
}
}
}
}
@@ -229,6 +244,11 @@ impl BucketEntry {
}
}
}
impl GetDocCount for BucketEntry {
fn doc_count(&self) -> u64 {
self.doc_count
}
}
/// This is the range entry for a bucket, which contains a key, count, and optionally
/// sub_aggregations.

View File

@@ -13,9 +13,7 @@ use crate::aggregation::f64_from_fastfield_u64;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::{
SegmentAggregationResultsCollector, SegmentHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector;
use crate::fastfield::{DynamicFastFieldReader, FastFieldReader};
use crate::schema::Type;
use crate::{DocId, TantivyError};
@@ -159,6 +157,27 @@ impl HistogramBounds {
}
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentHistogramBucketEntry {
pub key: f64,
pub doc_count: u64,
}
impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: SegmentAggregationResultsCollector,
agg_with_accessor: &AggregationsWithAccessor,
) -> IntermediateHistogramBucketEntry {
IntermediateHistogramBucketEntry {
key: self.key,
doc_count: self.doc_count,
sub_aggregation: sub_aggregation
.into_intermediate_aggregations_result(agg_with_accessor),
}
}
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug, PartialEq)]
@@ -174,7 +193,10 @@ pub struct SegmentHistogramCollector {
}
impl SegmentHistogramCollector {
pub fn into_intermediate_bucket_result(self) -> IntermediateBucketResult {
pub fn into_intermediate_bucket_result(
self,
agg_with_accessor: &BucketAggregationWithAccessor,
) -> IntermediateBucketResult {
let mut buckets = Vec::with_capacity(
self.buckets
.iter()
@@ -193,7 +215,12 @@ impl SegmentHistogramCollector {
.into_iter()
.zip(sub_aggregations.into_iter())
.filter(|(bucket, _sub_aggregation)| bucket.doc_count != 0)
.map(|(bucket, sub_aggregation)| (bucket, sub_aggregation).into()),
.map(|(bucket, sub_aggregation)| {
bucket.into_intermediate_bucket_entry(
sub_aggregation,
&agg_with_accessor.sub_aggregation,
)
}),
)
} else {
buckets.extend(
@@ -273,12 +300,13 @@ impl SegmentHistogramCollector {
let get_bucket_num =
|val| (get_bucket_num_f64(val, interval, offset) as i64 - first_bucket_num) as usize;
let accessor = bucket_with_accessor.accessor.as_single();
let mut iter = doc.chunks_exact(4);
for docs in iter.by_ref() {
let val0 = self.f64_from_fastfield_u64(bucket_with_accessor.accessor.get(docs[0]));
let val1 = self.f64_from_fastfield_u64(bucket_with_accessor.accessor.get(docs[1]));
let val2 = self.f64_from_fastfield_u64(bucket_with_accessor.accessor.get(docs[2]));
let val3 = self.f64_from_fastfield_u64(bucket_with_accessor.accessor.get(docs[3]));
let val0 = self.f64_from_fastfield_u64(accessor.get(docs[0]));
let val1 = self.f64_from_fastfield_u64(accessor.get(docs[1]));
let val2 = self.f64_from_fastfield_u64(accessor.get(docs[2]));
let val3 = self.f64_from_fastfield_u64(accessor.get(docs[3]));
let bucket_pos0 = get_bucket_num(val0);
let bucket_pos1 = get_bucket_num(val1);
@@ -315,8 +343,7 @@ impl SegmentHistogramCollector {
);
}
for doc in iter.remainder() {
let val =
f64_from_fastfield_u64(bucket_with_accessor.accessor.get(*doc), &self.field_type);
let val = f64_from_fastfield_u64(accessor.get(*doc), &self.field_type);
if !bounds.contains(val) {
continue;
}
@@ -630,41 +657,9 @@ mod tests {
};
use crate::aggregation::metric::{AverageAggregation, StatsAggregation};
use crate::aggregation::tests::{
get_test_index_2_segments, get_test_index_from_values, get_test_index_with_num_docs,
exec_request, exec_request_with_query, get_test_index_2_segments,
get_test_index_from_values, get_test_index_with_num_docs,
};
use crate::aggregation::AggregationCollector;
use crate::query::{AllQuery, TermQuery};
use crate::schema::IndexRecordOption;
use crate::{Index, Term};
fn exec_request(agg_req: Aggregations, index: &Index) -> crate::Result<Value> {
exec_request_with_query(agg_req, index, None)
}
fn exec_request_with_query(
agg_req: Aggregations,
index: &Index,
query: Option<(&str, &str)>,
) -> crate::Result<Value> {
let collector = AggregationCollector::from_aggs(agg_req);
let reader = index.reader()?;
let searcher = reader.searcher();
let agg_res = if let Some((field, term)) = query {
let text_field = reader.searcher().schema().get_field(field).unwrap();
let term_query = TermQuery::new(
Term::from_field_text(text_field, term),
IndexRecordOption::Basic,
);
searcher.search(&term_query, &collector)?
} else {
searcher.search(&AllQuery, &collector)?
};
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
Ok(res)
}
#[test]
fn histogram_test_crooked_values() -> crate::Result<()> {

View File

@@ -9,8 +9,10 @@
mod histogram;
mod range;
mod term_agg;
pub(crate) use histogram::SegmentHistogramCollector;
pub use histogram::*;
pub(crate) use range::SegmentRangeCollector;
pub use range::*;
pub use term_agg::*;

View File

@@ -1,3 +1,4 @@
use std::fmt::Debug;
use std::ops::Range;
use serde::{Deserialize, Serialize};
@@ -5,10 +6,10 @@ use serde::{Deserialize, Serialize};
use crate::aggregation::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor,
};
use crate::aggregation::intermediate_agg_result::IntermediateBucketResult;
use crate::aggregation::segment_agg_result::{
SegmentAggregationResultsCollector, SegmentRangeBucketEntry,
use crate::aggregation::intermediate_agg_result::{
IntermediateBucketResult, IntermediateRangeBucketEntry,
};
use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector;
use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key};
use crate::fastfield::FastFieldReader;
use crate::schema::Type;
@@ -102,8 +103,53 @@ pub struct SegmentRangeCollector {
field_type: Type,
}
#[derive(Clone, PartialEq)]
pub(crate) struct SegmentRangeBucketEntry {
pub key: Key,
pub doc_count: u64,
pub sub_aggregation: Option<SegmentAggregationResultsCollector>,
/// The from range of the bucket. Equals f64::MIN when None.
pub from: Option<f64>,
/// The to range of the bucket. Equals f64::MAX when None.
pub to: Option<f64>,
}
impl Debug for SegmentRangeBucketEntry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentRangeBucketEntry")
.field("key", &self.key)
.field("doc_count", &self.doc_count)
.field("from", &self.from)
.field("to", &self.to)
.finish()
}
}
impl SegmentRangeBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
agg_with_accessor: &AggregationsWithAccessor,
) -> IntermediateRangeBucketEntry {
let sub_aggregation = if let Some(sub_aggregation) = self.sub_aggregation {
sub_aggregation.into_intermediate_aggregations_result(agg_with_accessor)
} else {
Default::default()
};
IntermediateRangeBucketEntry {
key: self.key,
doc_count: self.doc_count,
sub_aggregation,
from: self.from,
to: self.to,
}
}
}
impl SegmentRangeCollector {
pub fn into_intermediate_bucket_result(self) -> IntermediateBucketResult {
pub fn into_intermediate_bucket_result(
self,
agg_with_accessor: &BucketAggregationWithAccessor,
) -> IntermediateBucketResult {
let field_type = self.field_type;
let buckets = self
@@ -112,7 +158,9 @@ impl SegmentRangeCollector {
.map(move |range_bucket| {
(
range_to_string(&range_bucket.range, &field_type),
range_bucket.bucket.into(),
range_bucket
.bucket
.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation),
)
})
.collect();
@@ -175,11 +223,12 @@ impl SegmentRangeCollector {
force_flush: bool,
) {
let mut iter = doc.chunks_exact(4);
let accessor = bucket_with_accessor.accessor.as_single();
for docs in iter.by_ref() {
let val1 = bucket_with_accessor.accessor.get(docs[0]);
let val2 = bucket_with_accessor.accessor.get(docs[1]);
let val3 = bucket_with_accessor.accessor.get(docs[2]);
let val4 = bucket_with_accessor.accessor.get(docs[3]);
let val1 = accessor.get(docs[0]);
let val2 = accessor.get(docs[1]);
let val3 = accessor.get(docs[2]);
let val4 = accessor.get(docs[3]);
let bucket_pos1 = self.get_bucket_pos(val1);
let bucket_pos2 = self.get_bucket_pos(val2);
let bucket_pos3 = self.get_bucket_pos(val3);
@@ -191,7 +240,7 @@ impl SegmentRangeCollector {
self.increment_bucket(bucket_pos4, docs[3], &bucket_with_accessor.sub_aggregation);
}
for doc in iter.remainder() {
let val = bucket_with_accessor.accessor.get(*doc);
let val = accessor.get(*doc);
let bucket_pos = self.get_bucket_pos(val);
self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation);
}
@@ -487,11 +536,7 @@ mod tests {
#[test]
fn range_binary_search_test_f64() {
let ranges = vec![
//(f64::MIN..10.0).into(),
(10.0..100.0).into(),
//(100.0..f64::MAX).into(),
];
let ranges = vec![(10.0..100.0).into()];
let collector = get_collector_from_ranges(ranges, Type::F64);
let search = |val: u64| collector.get_bucket_pos(val);

View File

@@ -0,0 +1,602 @@
use std::fmt::Debug;
use fnv::FnvHashMap;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateBucketResult, IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector;
use crate::fastfield::MultiValuedFastFieldReader;
use crate::schema::Type;
use crate::DocId;
/// Creates one bucket for every unique term
///
/// ### Terminology
/// Shard and Segment are equivalent.
///
/// ## Document count error
/// To improve performance, results from one segment are cut off at `shard_size`. On a single
/// segment this is fine. When combining results of multiple segments, terms that
/// don't make it into the top n of a shard increase the theoretical upper bound error by the
/// lowest term count.
///
/// Even with a larger `shard_size` value, doc_count values for a terms aggregation may be
/// approximate. As a result, any sub-aggregations on the terms aggregation may also be approximate.
/// sum_other_doc_count is the number of documents that didn't make it into the top `size` terms.
/// If this is greater than 0, you can be sure that the terms agg had to throw away some buckets,
/// either because they didn't fit into `size` on the root node or they didn't fit into
/// `shard_size` on the leaf node.
///
/// ## Per bucket document count error
/// If you set the show_term_doc_count_error parameter to true, the terms aggregation will include
/// doc_count_error_upper_bound, which is an upper bound to the error on the doc_count returned by
/// each shard. It's the sum of the size of the largest bucket on each shard that didn't fit into
/// shard_size.
///
/// Result type is [BucketResult](crate::aggregation::agg_result::BucketResult) with
/// [BucketEntry](crate::aggregation::agg_result::BucketEntry) on the
/// AggregationCollector.
///
/// Result type is
/// [crate::aggregation::intermediate_agg_result::IntermediateBucketResult] with
/// [crate::aggregation::intermediate_agg_result::IntermediateTermBucketEntry] on the
/// DistributedAggregationCollector.
///
/// # Limitations/Compatibility
///
/// # Request JSON Format
/// ```json
/// {
/// "genres": {
/// "terms": { "field": "genre" }
/// }
/// }
/// ```
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct TermsAggregation {
/// The field to aggregate on.
pub field: String,
/// By default, the top 10 terms with the most documents are returned.
/// Larger values for size are more expensive.
pub size: Option<u32>,
/// To get more accurate results, we fetch more than `size` from each segment.
/// By default we fetch `shard_size` terms, which defaults to size * 1.5 + 10.
pub shard_size: Option<u32>,
/// If you set the `show_term_doc_count_error` parameter to true, the terms aggregation will
/// include doc_count_error_upper_bound, which is an upper bound to the error on the
/// doc_count returned by each shard. It's the sum of the size of the largest bucket on
/// each segment that didn't fit into `shard_size`.
#[serde(default = "default_show_term_doc_count_error")]
pub show_term_doc_count_error: bool,
/// Filter out all terms with a doc count lower than `min_doc_count`.
pub min_doc_count: Option<usize>,
}
impl Default for TermsAggregation {
fn default() -> Self {
Self {
field: Default::default(),
size: Default::default(),
shard_size: Default::default(),
show_term_doc_count_error: true,
min_doc_count: Default::default(),
}
}
}
fn default_show_term_doc_count_error() -> bool {
true
}
/// Same as TermsAggregation, but with populated defaults.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct TermsAggregationInternal {
/// The field to aggregate on.
pub field: String,
/// By default, the top 10 terms with the most documents are returned.
/// Larger values for size are more expensive.
pub size: u32,
/// To get more accurate results, we fetch more than `size` from each segment.
/// By default we fetch `shard_size` terms, which defaults to size * 1.5 + 10.
///
/// Cannot be smaller than size. In that case it will be set automatically to size.
pub shard_size: u32,
/// If you set the `show_term_doc_count_error` parameter to true, the terms aggregation will
/// include doc_count_error_upper_bound, which is an upper bound to the error on the
/// doc_count returned by each shard. It's the sum of the size of the largest bucket on
/// each segment that didn't fit into `shard_size`.
pub show_term_doc_count_error: bool,
/// Filter out all terms with a doc count lower than `min_doc_count`.
pub min_doc_count: Option<usize>,
}
impl TermsAggregationInternal {
pub(crate) fn from_req(req: &TermsAggregation) -> Self {
let size = req.size.unwrap_or(10);
let mut shard_size = req
.shard_size
.unwrap_or((size as f32 * 1.5_f32) as u32 + 10);
shard_size = shard_size.max(size);
TermsAggregationInternal {
field: req.field.to_string(),
size,
shard_size,
show_term_doc_count_error: req.show_term_doc_count_error,
min_doc_count: req.min_doc_count,
}
}
}
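// Illustrative examples (not part of this commit) of how the defaults above resolve:
// `TermsAggregation { ..Default::default() }` yields size = 10 and
// shard_size = (10 as f32 * 1.5) as u32 + 10 = 25, while size = 100 with an
// explicit shard_size = 20 is bumped up to shard_size = 100, because the
// per-segment cutoff is never allowed to be smaller than the final size.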
const TERM_BUCKET_SIZE: usize = 100;
#[derive(Clone, Debug, PartialEq)]
/// Chunks the term_id value range in TERM_BUCKET_SIZE blocks.
struct TermBuckets {
pub(crate) entries: FnvHashMap<u32, TermBucketEntry>,
blueprint: Option<SegmentAggregationResultsCollector>,
}
#[derive(Clone, PartialEq, Default)]
struct TermBucketEntry {
doc_count: u64,
sub_aggregations: Option<SegmentAggregationResultsCollector>,
}
impl Debug for TermBucketEntry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TermBucketEntry")
.field("doc_count", &self.doc_count)
.finish()
}
}
impl TermBucketEntry {
fn from_blueprint(blueprint: &Option<SegmentAggregationResultsCollector>) -> Self {
Self {
doc_count: 0,
sub_aggregations: blueprint.clone(),
}
}
pub(crate) fn into_intermediate_bucket_entry(
self,
agg_with_accessor: &AggregationsWithAccessor,
) -> IntermediateTermBucketEntry {
let sub_aggregation = if let Some(sub_aggregation) = self.sub_aggregations {
sub_aggregation.into_intermediate_aggregations_result(agg_with_accessor)
} else {
Default::default()
};
IntermediateTermBucketEntry {
doc_count: self.doc_count,
sub_aggregation,
}
}
}
impl TermBuckets {
pub(crate) fn from_req_and_validate(
sub_aggregation: &AggregationsWithAccessor,
max_term_id: usize,
) -> crate::Result<Self> {
let has_sub_aggregations = !sub_aggregation.is_empty();
let _num_chunks = (max_term_id / TERM_BUCKET_SIZE) + 1;
let blueprint = if has_sub_aggregations {
let sub_aggregation =
SegmentAggregationResultsCollector::from_req_and_validate(sub_aggregation)?;
Some(sub_aggregation)
} else {
None
};
Ok(TermBuckets {
blueprint,
entries: Default::default(),
})
}
fn increment_bucket(
&mut self,
term_ids: &[u64],
doc: DocId,
bucket_with_accessor: &AggregationsWithAccessor,
blueprint: &Option<SegmentAggregationResultsCollector>,
) {
// self.ensure_vec_exists(term_ids);
for &term_id in term_ids {
let entry = self
.entries
.entry(term_id as u32)
.or_insert_with(|| TermBucketEntry::from_blueprint(blueprint));
entry.doc_count += 1;
if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
sub_aggregations.collect(doc, bucket_with_accessor);
}
}
}
fn force_flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) {
for entry in &mut self.entries.values_mut() {
if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
sub_aggregations.flush_staged_docs(agg_with_accessor, false);
}
}
}
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug, PartialEq)]
pub struct SegmentTermCollector {
/// The buckets containing the aggregation data.
term_buckets: TermBuckets,
req: TermsAggregationInternal,
field_type: Type,
blueprint: Option<SegmentAggregationResultsCollector>,
}
impl SegmentTermCollector {
pub(crate) fn from_req_and_validate(
req: &TermsAggregation,
sub_aggregations: &AggregationsWithAccessor,
field_type: Type,
accessor: &MultiValuedFastFieldReader<u64>,
) -> crate::Result<Self> {
let max_term_id = accessor.max_value();
let term_buckets =
TermBuckets::from_req_and_validate(sub_aggregations, max_term_id as usize)?;
let has_sub_aggregations = !sub_aggregations.is_empty();
let blueprint = if has_sub_aggregations {
let sub_aggregation =
SegmentAggregationResultsCollector::from_req_and_validate(sub_aggregations)?;
Some(sub_aggregation)
} else {
None
};
Ok(SegmentTermCollector {
req: TermsAggregationInternal::from_req(req),
term_buckets,
field_type,
blueprint,
})
}
pub(crate) fn into_intermediate_bucket_result(
self,
agg_with_accessor: &BucketAggregationWithAccessor,
) -> IntermediateBucketResult {
let mut entries: Vec<_> = self.term_buckets.entries.into_iter().collect();
let (term_doc_count_before_cutoff, sum_other_doc_count) =
cut_off_buckets(&mut entries, self.req.shard_size as usize);
let inverted_index = agg_with_accessor
.inverted_index
.as_ref()
.expect("internal error: inverted index not loaded for term aggregation");
let term_dict = inverted_index.terms();
let mut dict: FnvHashMap<String, IntermediateTermBucketEntry> = Default::default();
let mut buffer = vec![];
for (term_id, entry) in entries {
term_dict
.ord_to_term(term_id as u64, &mut buffer)
.expect("could not find term");
dict.insert(
String::from_utf8(buffer.to_vec()).unwrap(),
entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation),
);
}
IntermediateBucketResult::Terms(IntermediateTermBucketResult {
entries: dict,
sum_other_doc_count,
doc_count_error_upper_bound: term_doc_count_before_cutoff,
})
}
#[inline]
pub(crate) fn collect_block(
&mut self,
doc: &[DocId],
bucket_with_accessor: &BucketAggregationWithAccessor,
force_flush: bool,
) {
let accessor = bucket_with_accessor.accessor.as_multi();
let mut iter = doc.chunks_exact(4);
let mut vals1 = vec![];
let mut vals2 = vec![];
let mut vals3 = vec![];
let mut vals4 = vec![];
for docs in iter.by_ref() {
accessor.get_vals(docs[0], &mut vals1);
accessor.get_vals(docs[1], &mut vals2);
accessor.get_vals(docs[2], &mut vals3);
accessor.get_vals(docs[3], &mut vals4);
self.term_buckets.increment_bucket(
&vals1,
docs[0],
&bucket_with_accessor.sub_aggregation,
&self.blueprint,
);
self.term_buckets.increment_bucket(
&vals2,
docs[1],
&bucket_with_accessor.sub_aggregation,
&self.blueprint,
);
self.term_buckets.increment_bucket(
&vals3,
docs[2],
&bucket_with_accessor.sub_aggregation,
&self.blueprint,
);
self.term_buckets.increment_bucket(
&vals4,
docs[3],
&bucket_with_accessor.sub_aggregation,
&self.blueprint,
);
}
for &doc in iter.remainder() {
accessor.get_vals(doc, &mut vals1);
self.term_buckets.increment_bucket(
&vals1,
doc,
&bucket_with_accessor.sub_aggregation,
&self.blueprint,
);
}
if force_flush {
self.term_buckets
.force_flush(&bucket_with_accessor.sub_aggregation);
}
}
}
pub(crate) trait GetDocCount {
fn doc_count(&self) -> u64;
}
impl GetDocCount for (u32, TermBucketEntry) {
fn doc_count(&self) -> u64 {
self.1.doc_count
}
}
pub(crate) fn cut_off_buckets<T: GetDocCount + Debug>(
entries: &mut Vec<T>,
num_elem: usize,
) -> (u64, u64) {
entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.doc_count()));
let term_doc_count_before_cutoff = entries
.get(num_elem)
.map(|entry| entry.doc_count())
.unwrap_or(0);
let sum_other_doc_count = entries
.get(num_elem..)
.map(|cut_off_range| cut_off_range.iter().map(|entry| entry.doc_count()).sum())
.unwrap_or(0);
entries.truncate(num_elem);
(term_doc_count_before_cutoff, sum_other_doc_count)
}
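// Worked example (illustrative, not part of this commit): with doc counts
// [5, 3, 2, 1] and num_elem = 2, the entries are sorted in descending order,
// the first cut-off entry contributes its doc_count of 2 as the error upper
// bound, the cut-off tail sums to 2 + 1 = 3 as sum_other_doc_count, and the
// vec is truncated to the two largest entries.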
#[cfg(test)]
mod tests {
use super::*;
use crate::aggregation::agg_req::{
Aggregation, Aggregations, BucketAggregation, BucketAggregationType,
};
use crate::aggregation::tests::{exec_request, get_test_index_from_terms};
#[test]
fn terms_aggregation_test_single_segment() -> crate::Result<()> {
terms_aggregation_test_merge_segment(true)
}
#[test]
fn terms_aggregation_test() -> crate::Result<()> {
terms_aggregation_test_merge_segment(false)
}
fn terms_aggregation_test_merge_segment(merge_segments: bool) -> crate::Result<()> {
let segment_and_terms = vec![
vec!["terma"],
vec!["termb"],
vec!["termc"],
vec!["terma"],
vec!["terma"],
vec!["terma"],
vec!["termb"],
vec!["terma"],
];
let index = get_test_index_from_terms(merge_segments, &segment_and_terms)?;
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2);
assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 1);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "termc");
assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
shard_size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2);
assert_eq!(
res["my_texts"]["buckets"][2]["key"],
serde_json::Value::Null
);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 1);
Ok(())
}
#[test]
fn terms_aggregation_error_count_test() -> crate::Result<()> {
let terms_per_segment = vec![
vec!["terma", "terma", "termb", "termb", "termb", "termc"], /* termc doesn't make it
* from this segment */
vec!["terma", "terma", "termb", "termc", "termc"], /* termb doesn't make it from
* this segment */
];
let index = get_test_index_from_terms(false, &terms_per_segment)?;
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
shard_size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req, &index)?;
println!("{}", &serde_json::to_string_pretty(&res).unwrap());
assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 4);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 3);
assert_eq!(
res["my_texts"]["buckets"][2]["doc_count"],
serde_json::Value::Null
);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 4);
assert_eq!(res["my_texts"]["doc_count_error_upper_bound"], 2);
Ok(())
}
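// Walk-through of the assertions above (illustrative, not part of this
// commit): with shard_size = 2, segment 1 keeps termb (3) and terma (2) and
// cuts off termc (1); segment 2 keeps terma (2) and termc (2) and cuts off
// termb (1). Each segment reports the doc count of its first cut-off term, so
// the merged doc_count_error_upper_bound is 1 + 1 = 2 and the merged
// sum_other_doc_count is 1 + 1 = 2. The final cut to size = 2 then drops
// termc with its merged count of 2, so sum_other_doc_count becomes 2 + 2 = 4,
// with terma = 4 and termb = 3 remaining.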
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use fnv::FnvHashMap;
use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::thread_rng;
use super::*;
fn get_collector_with_buckets(num_docs: u64) -> TermBuckets {
TermBuckets::from_req_and_validate(&Default::default(), num_docs as usize).unwrap()
}
fn get_rand_terms(total_terms: u64, num_terms_returned: u64) -> Vec<u64> {
let mut rng = thread_rng();
let all_terms = (0..total_terms - 1).collect_vec();
let mut vals = vec![];
for _ in 0..num_terms_returned {
let val = all_terms.as_slice().choose(&mut rng).unwrap();
vals.push(*val);
}
vals
}
fn bench_term_hashmap(b: &mut test::Bencher, num_terms: u64, total_terms: u64) {
let mut collector = FnvHashMap::default();
let vals = get_rand_terms(total_terms, num_terms);
b.iter(|| {
for val in &vals {
let val = collector.entry(val).or_insert(TermBucketEntry::default());
val.doc_count += 1;
}
collector.get(&0).cloned()
})
}
fn bench_term_buckets(b: &mut test::Bencher, num_terms: u64, total_terms: u64) {
let mut collector = get_collector_with_buckets(total_terms);
let vals = get_rand_terms(total_terms, num_terms);
let aggregations_with_accessor: AggregationsWithAccessor = Default::default();
b.iter(|| {
for &val in &vals {
collector.increment_bucket(&[val], 0, &aggregations_with_accessor, &None);
}
})
}
#[bench]
fn bench_term_buckets_500_of_1_000_000(b: &mut test::Bencher) {
bench_term_buckets(b, 500u64, 1_000_000u64)
}
#[bench]
fn bench_fnv_buckets_500_of_1_000_000(b: &mut test::Bencher) {
bench_term_hashmap(b, 500u64, 1_000_000u64)
}
#[bench]
fn bench_term_buckets_1_000_000_of_50_000(b: &mut test::Bencher) {
bench_term_buckets(b, 1_000_000u64, 50_000u64)
}
#[bench]
fn bench_fnv_buckets_1_000_000_of_50_000(b: &mut test::Bencher) {
bench_term_hashmap(b, 1_000_000u64, 50_000u64)
}
#[bench]
fn bench_term_buckets_1_000_000_of_50(b: &mut test::Bencher) {
bench_term_buckets(b, 1_000_000u64, 50u64)
}
#[bench]
fn bench_fnv_buckets_1_000_000_of_50(b: &mut test::Bencher) {
bench_term_hashmap(b, 1_000_000u64, 50u64)
}
}

View File

@@ -106,7 +106,7 @@ fn merge_fruits(
/// AggregationSegmentCollector does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs: AggregationsWithAccessor,
aggs_with_accessor: AggregationsWithAccessor,
result: SegmentAggregationResultsCollector,
}
@@ -121,7 +121,7 @@ impl AggregationSegmentCollector {
let result =
SegmentAggregationResultsCollector::from_req_and_validate(&aggs_with_accessor)?;
Ok(AggregationSegmentCollector {
aggs: aggs_with_accessor,
aggs_with_accessor,
result,
})
}
@@ -132,11 +132,13 @@ impl SegmentCollector for AggregationSegmentCollector {
#[inline]
fn collect(&mut self, doc: crate::DocId, _score: crate::Score) {
self.result.collect(doc, &self.aggs);
self.result.collect(doc, &self.aggs_with_accessor);
}
fn harvest(mut self) -> Self::Fruit {
self.result.flush_staged_docs(&self.aggs, true);
self.result.into()
self.result
.flush_staged_docs(&self.aggs_with_accessor, true);
self.result
.into_intermediate_aggregations_result(&self.aggs_with_accessor)
}
}

View File

@@ -9,12 +9,13 @@ use itertools::Itertools;
use serde::{Deserialize, Serialize};
use super::agg_req::{AggregationsInternal, BucketAggregationType, MetricAggregation};
use super::agg_result::BucketResult;
use super::bucket::{cut_off_buckets, SegmentHistogramBucketEntry, TermsAggregation};
use super::metric::{IntermediateAverage, IntermediateStats};
use super::segment_agg_result::{
SegmentAggregationResultsCollector, SegmentBucketResultCollector, SegmentHistogramBucketEntry,
SegmentMetricResultCollector, SegmentRangeBucketEntry,
};
use super::segment_agg_result::SegmentMetricResultCollector;
use super::{Key, SerializedKey, VecWithNames};
use crate::aggregation::agg_result::{AggregationResults, BucketEntry};
use crate::aggregation::bucket::TermsAggregationInternal;
/// Contains the intermediate aggregation result, which is optimized to be merged with other
/// intermediate results.
@@ -24,15 +25,6 @@ pub struct IntermediateAggregationResults {
pub(crate) buckets: Option<VecWithNames<IntermediateBucketResult>>,
}
impl From<SegmentAggregationResultsCollector> for IntermediateAggregationResults {
fn from(tree: SegmentAggregationResultsCollector) -> Self {
let metrics = tree.metrics.map(VecWithNames::from_other);
let buckets = tree.buckets.map(VecWithNames::from_other);
Self { metrics, buckets }
}
}
impl IntermediateAggregationResults {
pub(crate) fn empty_from_req(req: &AggregationsInternal) -> Self {
let metrics = if req.metrics.is_empty() {
@@ -169,22 +161,14 @@ pub enum IntermediateBucketResult {
/// The buckets
buckets: Vec<IntermediateHistogramBucketEntry>,
},
}
impl From<SegmentBucketResultCollector> for IntermediateBucketResult {
fn from(collector: SegmentBucketResultCollector) -> Self {
match collector {
SegmentBucketResultCollector::Range(range) => range.into_intermediate_bucket_result(),
SegmentBucketResultCollector::Histogram(histogram) => {
histogram.into_intermediate_bucket_result()
}
}
}
/// Term aggregation
Terms(IntermediateTermBucketResult),
}
impl IntermediateBucketResult {
pub(crate) fn empty_from_req(req: &BucketAggregationType) -> Self {
match req {
BucketAggregationType::Terms(_) => IntermediateBucketResult::Terms(Default::default()),
BucketAggregationType::Range(_) => IntermediateBucketResult::Range(Default::default()),
BucketAggregationType::Histogram(_) => {
IntermediateBucketResult::Histogram { buckets: vec![] }
@@ -193,6 +177,16 @@ impl IntermediateBucketResult {
}
fn merge_fruits(&mut self, other: IntermediateBucketResult) {
match (self, other) {
(
IntermediateBucketResult::Terms(entries_left),
IntermediateBucketResult::Terms(entries_right),
) => {
merge_maps(&mut entries_left.entries, entries_right.entries);
entries_left.sum_other_doc_count += entries_right.sum_other_doc_count;
entries_left.doc_count_error_upper_bound +=
entries_right.doc_count_error_upper_bound;
}
(
IntermediateBucketResult::Range(entries_left),
IntermediateBucketResult::Range(entries_right),
@@ -232,6 +226,59 @@ impl IntermediateBucketResult {
(IntermediateBucketResult::Histogram { .. }, _) => {
panic!("try merge on different types")
}
(IntermediateBucketResult::Terms { .. }, _) => {
panic!("try merge on different types")
}
}
}
}
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
/// Term aggregation including error counts
pub struct IntermediateTermBucketResult {
pub(crate) entries: FnvHashMap<String, IntermediateTermBucketEntry>,
pub(crate) sum_other_doc_count: u64,
pub(crate) doc_count_error_upper_bound: u64,
}
impl IntermediateTermBucketResult {
pub(crate) fn into_final_result(
self,
req: &TermsAggregation,
sub_aggregation_req: &AggregationsInternal,
) -> BucketResult {
let req = TermsAggregationInternal::from_req(req);
let mut buckets: Vec<BucketEntry> = self
.entries
.into_iter()
.map(|(key, entry)| BucketEntry {
key: Key::Str(key),
doc_count: entry.doc_count,
sub_aggregation: AggregationResults::from_intermediate_and_req_internal(
entry.sub_aggregation,
sub_aggregation_req,
),
})
.collect();
buckets.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.doc_count));
// We ignore _term_doc_count_before_cutoff here, because it increases the upper bound error
// only for terms that didn't make it into the top N.
//
// This can be interesting as a measure of the quality of the results, but it is not suitable
// for checking the actual error count of the returned terms.
let (_term_doc_count_before_cutoff, sum_other_doc_count) =
cut_off_buckets(&mut buckets, req.size as usize);
let doc_count_error_upper_bound = if req.show_term_doc_count_error {
Some(self.doc_count_error_upper_bound)
} else {
None
};
BucketResult::Terms {
buckets,
sum_other_doc_count: self.sum_other_doc_count + sum_other_doc_count,
doc_count_error_upper_bound,
}
}
}
@@ -277,26 +324,6 @@ impl From<SegmentHistogramBucketEntry> for IntermediateHistogramBucketEntry {
}
}
impl
From<(
SegmentHistogramBucketEntry,
SegmentAggregationResultsCollector,
)> for IntermediateHistogramBucketEntry
{
fn from(
entry: (
SegmentHistogramBucketEntry,
SegmentAggregationResultsCollector,
),
) -> Self {
IntermediateHistogramBucketEntry {
key: entry.0.key,
doc_count: entry.0.doc_count,
sub_aggregation: entry.1.into(),
}
}
}
/// This is the range entry for a bucket, which contains a key, count, and optionally
/// sub_aggregations.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@@ -305,7 +332,6 @@ pub struct IntermediateRangeBucketEntry {
pub key: Key,
/// The number of documents in the bucket.
pub doc_count: u64,
pub(crate) values: Option<Vec<u64>>,
/// The sub_aggregation in this bucket.
pub sub_aggregation: IntermediateAggregationResults,
/// The from range of the bucket. Equals f64::MIN when None.
@@ -316,22 +342,20 @@ pub struct IntermediateRangeBucketEntry {
pub to: Option<f64>,
}
impl From<SegmentRangeBucketEntry> for IntermediateRangeBucketEntry {
fn from(entry: SegmentRangeBucketEntry) -> Self {
let sub_aggregation = if let Some(sub_aggregation) = entry.sub_aggregation {
sub_aggregation.into()
} else {
Default::default()
};
/// This is the term entry for a bucket, which contains a count, and optionally
/// sub_aggregations.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateTermBucketEntry {
/// The number of documents in the bucket.
pub doc_count: u64,
/// The sub_aggregation in this bucket.
pub sub_aggregation: IntermediateAggregationResults,
}
IntermediateRangeBucketEntry {
key: entry.key,
doc_count: entry.doc_count,
values: None,
sub_aggregation,
to: entry.to,
from: entry.from,
}
impl MergeFruits for IntermediateTermBucketEntry {
fn merge_fruits(&mut self, other: IntermediateTermBucketEntry) {
self.doc_count += other.doc_count;
self.sub_aggregation.merge_fruits(other.sub_aggregation);
}
}
@@ -366,7 +390,6 @@ mod tests {
IntermediateRangeBucketEntry {
key: Key::Str(key.to_string()),
doc_count: *doc_count,
values: None,
sub_aggregation: Default::default(),
from: None,
to: None,
@@ -394,7 +417,6 @@ mod tests {
IntermediateRangeBucketEntry {
key: Key::Str(key.to_string()),
doc_count: *doc_count,
values: None,
from: None,
to: None,
sub_aggregation: get_sub_test_tree(&[(

View File

@@ -318,7 +318,7 @@ mod tests {
use crate::aggregation::segment_agg_result::DOC_BLOCK_SIZE;
use crate::aggregation::DistributedAggregationCollector;
use crate::query::{AllQuery, TermQuery};
use crate::schema::{Cardinality, IndexRecordOption, Schema, TextFieldIndexing};
use crate::schema::{Cardinality, IndexRecordOption, Schema, TextFieldIndexing, FAST, STRING};
use crate::{Index, Term};
fn get_avg_req(field_name: &str) -> Aggregation {
@@ -337,17 +337,79 @@ mod tests {
)
}
pub fn exec_request(agg_req: Aggregations, index: &Index) -> crate::Result<Value> {
exec_request_with_query(agg_req, index, None)
}
pub fn exec_request_with_query(
agg_req: Aggregations,
index: &Index,
query: Option<(&str, &str)>,
) -> crate::Result<Value> {
let collector = AggregationCollector::from_aggs(agg_req);
let reader = index.reader()?;
let searcher = reader.searcher();
let agg_res = if let Some((field, term)) = query {
let text_field = reader.searcher().schema().get_field(field).unwrap();
let term_query = TermQuery::new(
Term::from_field_text(text_field, term),
IndexRecordOption::Basic,
);
searcher.search(&term_query, &collector)?
} else {
searcher.search(&AllQuery, &collector)?
};
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
Ok(res)
}
pub fn get_test_index_from_values(
merge_segments: bool,
values: &[f64],
) -> crate::Result<Index> {
// Every value gets its own segment
let mut segment_and_values = vec![];
for value in values {
segment_and_values.push(vec![(*value, value.to_string())]);
}
get_test_index_from_values_and_terms(merge_segments, &segment_and_values)
}
pub fn get_test_index_from_terms(
merge_segments: bool,
values: &[Vec<&str>],
) -> crate::Result<Index> {
// Every inner vec of terms gets its own segment
let segment_and_values = values
.iter()
.map(|terms| {
terms
.iter()
.enumerate()
.map(|(i, term)| (i as f64, term.to_string()))
.collect()
})
.collect::<Vec<_>>();
get_test_index_from_values_and_terms(merge_segments, &segment_and_values)
}
pub fn get_test_index_from_values_and_terms(
merge_segments: bool,
segment_and_values: &[Vec<(f64, String)>],
) -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
let text_fieldtype = crate::schema::TextOptions::default()
.set_indexing_options(
TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs),
)
.set_fast()
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let text_field = schema_builder.add_text_field("text", text_fieldtype.clone());
let text_field_id = schema_builder.add_text_field("text_id", text_fieldtype);
let string_field_id = schema_builder.add_text_field("string_id", STRING | FAST);
let score_fieldtype =
crate::schema::NumericOptions::default().set_fast(Cardinality::SingleValue);
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
@@ -360,15 +422,20 @@ mod tests {
let index = Index::create_in_ram(schema_builder.build());
{
let mut index_writer = index.writer_for_tests()?;
for &i in values {
// writing the segment
index_writer.add_document(doc!(
text_field => "cool",
score_field => i as u64,
score_field_f64 => i as f64,
score_field_i64 => i as i64,
fraction_field => i as f64/100.0,
))?;
for values in segment_and_values {
for (i, term) in values {
let i = *i;
// writing the segment
index_writer.add_document(doc!(
text_field => "cool",
text_field_id => term.to_string(),
string_field_id => term.to_string(),
score_field => i as u64,
score_field_f64 => i as f64,
score_field_i64 => i as i64,
fraction_field => i as f64/100.0,
))?;
}
index_writer.commit()?;
}
}
@@ -968,7 +1035,7 @@ mod tests {
let agg_res = avg_on_field("text");
assert_eq!(
format!("{:?}", agg_res),
r#"InvalidArgument("Only single value fast fields of type f64, u64, i64 are supported, but got Str ")"#
r#"InvalidArgument("Only fast fields of type f64, u64, i64 are supported, but got Str ")"#
);
let agg_res = avg_on_field("not_exist_field");
@@ -989,11 +1056,12 @@ mod tests {
#[cfg(all(test, feature = "unstable"))]
mod bench {
use rand::prelude::SliceRandom;
use rand::{thread_rng, Rng};
use test::{self, Bencher};
use super::*;
use crate::aggregation::bucket::{HistogramAggregation, HistogramBounds};
use crate::aggregation::bucket::{HistogramAggregation, HistogramBounds, TermsAggregation};
use crate::aggregation::metric::StatsAggregation;
use crate::query::AllQuery;
@@ -1005,6 +1073,10 @@ mod tests {
)
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let text_field_many_terms =
schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_few_terms =
schema_builder.add_text_field("text_few_terms", STRING | FAST);
let score_fieldtype =
crate::schema::NumericOptions::default().set_fast(Cardinality::SingleValue);
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
@@ -1012,6 +1084,7 @@ mod tests {
schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let index = Index::create_from_tempdir(schema_builder.build())?;
let few_terms_data = vec!["INFO", "ERROR", "WARN", "DEBUG"];
{
let mut rng = thread_rng();
let mut index_writer = index.writer_for_tests()?;
@@ -1020,6 +1093,8 @@ mod tests {
let val: f64 = rng.gen_range(0.0..1_000_000.0);
index_writer.add_document(doc!(
text_field => "cool",
text_field_many_terms => val.to_string(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
score_field => val as u64,
score_field_f64 => val as f64,
score_field_i64 => val as i64,
@@ -1171,6 +1246,64 @@ mod tests {
});
}
#[bench]
fn bench_aggregation_terms_few(b: &mut Bencher) {
let index = get_test_index_bench(false).unwrap();
let reader = index.reader().unwrap();
b.iter(|| {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_few_terms".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req);
let searcher = reader.searcher();
let agg_res: AggregationResults =
searcher.search(&AllQuery, &collector).unwrap().into();
agg_res
});
}
#[bench]
fn bench_aggregation_terms_many(b: &mut Bencher) {
let index = get_test_index_bench(false).unwrap();
let reader = index.reader().unwrap();
b.iter(|| {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_many_terms".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req);
let searcher = reader.searcher();
let agg_res: AggregationResults =
searcher.search(&AllQuery, &collector).unwrap().into();
agg_res
});
}
#[bench]
fn bench_aggregation_range_only(b: &mut Bencher) {
let index = get_test_index_bench(false).unwrap();

View File

@@ -9,11 +9,12 @@ use super::agg_req::MetricAggregation;
use super::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor,
};
use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector};
use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector};
use super::intermediate_agg_result::{IntermediateAggregationResults, IntermediateBucketResult};
use super::metric::{
AverageAggregation, SegmentAverageCollector, SegmentStatsCollector, StatsAggregation,
};
use super::{Key, VecWithNames};
use super::VecWithNames;
use crate::aggregation::agg_req::BucketAggregationType;
use crate::DocId;
@@ -40,6 +41,23 @@ impl Debug for SegmentAggregationResultsCollector {
}
impl SegmentAggregationResultsCollector {
pub fn into_intermediate_aggregations_result(
self,
agg_with_accessor: &AggregationsWithAccessor,
) -> IntermediateAggregationResults {
let buckets = self.buckets.map(|buckets| {
let entries = buckets
.into_iter()
.zip(agg_with_accessor.buckets.values())
.map(|((key, bucket), acc)| (key, bucket.into_intermediate_bucket_result(acc)))
.collect::<Vec<(String, _)>>();
VecWithNames::from_entries(entries)
});
let metrics = self.metrics.map(VecWithNames::from_other);
IntermediateAggregationResults { metrics, buckets }
}
pub(crate) fn from_req_and_validate(req: &AggregationsWithAccessor) -> crate::Result<Self> {
let buckets = req
.buckets
@@ -97,6 +115,9 @@ impl SegmentAggregationResultsCollector {
agg_with_accessor: &AggregationsWithAccessor,
force_flush: bool,
) {
if self.num_staged_docs == 0 {
return;
}
if let Some(metrics) = &mut self.metrics {
for (collector, agg_with_accessor) in
metrics.values_mut().zip(agg_with_accessor.metrics.values())
@@ -162,12 +183,38 @@ impl SegmentMetricResultCollector {
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum SegmentBucketResultCollector {
Range(SegmentRangeCollector),
Histogram(SegmentHistogramCollector),
Histogram(Box<SegmentHistogramCollector>),
Terms(Box<SegmentTermCollector>),
}
impl SegmentBucketResultCollector {
pub fn into_intermediate_bucket_result(
self,
agg_with_accessor: &BucketAggregationWithAccessor,
) -> IntermediateBucketResult {
match self {
SegmentBucketResultCollector::Terms(terms) => {
terms.into_intermediate_bucket_result(agg_with_accessor)
}
SegmentBucketResultCollector::Range(range) => {
range.into_intermediate_bucket_result(agg_with_accessor)
}
SegmentBucketResultCollector::Histogram(histogram) => {
histogram.into_intermediate_bucket_result(agg_with_accessor)
}
}
}
pub fn from_req_and_validate(req: &BucketAggregationWithAccessor) -> crate::Result<Self> {
match &req.bucket_agg {
BucketAggregationType::Terms(terms_req) => Ok(Self::Terms(Box::new(
SegmentTermCollector::from_req_and_validate(
terms_req,
&req.sub_aggregation,
req.field_type,
req.accessor.as_multi(),
)?,
))),
BucketAggregationType::Range(range_req) => {
Ok(Self::Range(SegmentRangeCollector::from_req_and_validate(
range_req,
@@ -175,14 +222,14 @@ impl SegmentBucketResultCollector {
req.field_type,
)?))
}
BucketAggregationType::Histogram(histogram) => Ok(Self::Histogram(
BucketAggregationType::Histogram(histogram) => Ok(Self::Histogram(Box::new(
SegmentHistogramCollector::from_req_and_validate(
histogram,
&req.sub_aggregation,
req.field_type,
&req.accessor,
req.accessor.as_single(),
)?,
)),
))),
}
}
@@ -200,34 +247,9 @@ impl SegmentBucketResultCollector {
SegmentBucketResultCollector::Histogram(histogram) => {
histogram.collect_block(doc, bucket_with_accessor, force_flush)
}
SegmentBucketResultCollector::Terms(terms) => {
terms.collect_block(doc, bucket_with_accessor, force_flush)
}
}
}
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentHistogramBucketEntry {
pub key: f64,
pub doc_count: u64,
}
#[derive(Clone, PartialEq)]
pub(crate) struct SegmentRangeBucketEntry {
pub key: Key,
pub doc_count: u64,
pub sub_aggregation: Option<SegmentAggregationResultsCollector>,
/// The from range of the bucket. Equals f64::MIN when None.
pub from: Option<f64>,
/// The to range of the bucket. Equals f64::MAX when None.
pub to: Option<f64>,
}
impl Debug for SegmentRangeBucketEntry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentRangeBucketEntry")
.field("key", &self.key)
.field("doc_count", &self.doc_count)
.field("from", &self.from)
.field("to", &self.to)
.finish()
}
}